{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as Data\n",
    "torch.manual_seed(8) # for reproduce\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import gc\n",
    "import sys\n",
    "sys.setrecursionlimit(50000)\n",
    "import pickle\n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "from tensorboardX import SummaryWriter\n",
    "torch.nn.Module.dump_patches = True\n",
    "import copy\n",
    "import pandas as pd\n",
    "#then import my own modules\n",
    "from AttentiveFP import Fingerprint, Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import matthews_corrcoef\n",
    "from sklearn.metrics import recall_score\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import r2_score\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import auc\n",
    "from sklearn.metrics import f1_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from rdkit.Chem import rdMolDescriptors, MolSurf\n",
    "# from rdkit.Chem.Draw import SimilarityMaps\n",
    "from rdkit import Chem\n",
    "# from rdkit.Chem import AllChem\n",
    "from rdkit.Chem import QED\n",
    "%matplotlib inline\n",
    "from numpy.polynomial.polynomial import polyfit\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib\n",
    "from IPython.display import SVG, display\n",
    "import seaborn as sns; sns.set(color_codes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of all smiles:  8597\n",
      "not successfully processed smiles:  [F-][B+3]([F-])([F-])[F-].CC[N+]1(C)CCCC1\n",
      "not successfully processed smiles:  [NH4+].[NH4+].[Cl-][Pt++]([Cl-])([Cl-])[Cl-]\n",
      "not successfully processed smiles:  [Cl-][Pt]1([Cl-])[NH2+]CC[NH2+]1\n",
      "not successfully processed smiles:  [Na+].[Na+].F[Si--](F)(F)(F)(F)F\n",
      "not successfully processed smiles:  [NH4+].[NH4+].F[Si--](F)(F)(F)(F)F\n",
      "not successfully processed smiles:  O.O.O.O.O=C1O[Mg]2(OC(=O)C3=CC=CC=C3O2)OC2=CC=CC=C12\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "not successfully processed smiles:  FAIL\n",
      "number of successfully processed smiles:  8576\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU8AAAC/CAYAAAB+KF5fAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEqVJREFUeJzt3X9M1df9x/HXvYJSkQIzt/iNuKq93ptSK3SritNohuK6RUuwTVdNttG52C0ldVl1uhmMS8xmkEgWiLHWmDXptq5ZOgklFbVuuqllRqNpJ+UCNo3RiKj8EHBX4N7vHwvXXuH+OoV7L97nIyGRc9733MPn433xuefzuR8sXq/XKwBARKyxngAAjEeEJwAYIDwBwADhCQAGCE8AMEB4AoABwhMADBCeAGCA8AQAA4QnABggPAHAAOEJAAYITwAwkBTrCURbR0evPJ6RbyQ1deoU3brVE+UZ4UHsh/iQKPvBarUoMzM14sclXHh6PN6A4TnUj9hjP8QH9kNgvG0HAAOEJwAYIDwBwADhCQAGEu6EUbQNeCR3/0DA/knJSUriVxgw7hCeY8zdP6CzjW0B++c/maWkSewGYLzhmAcADBCeAGAgZHieOXNGW7du1Xe+8x3l5uZq6dKlKi0tVVNT07DaU6dO6aWXXtK8efO0aNEibd++Xd3d3cPqent7tXPnTi1ZskTz5s3TmjVr9NFHH434/OGOCQDRFDI8//znP+vatWsqKSnRW2+9pa1bt+ratWt68cUXdeHCBV9dQ0ODNmzYoGnTpmnfvn3asmWLjh8/rg0bNsjj8fiNWVpaqtraWm3cuFFvvvmm7Ha7SktLdeLECb+6SMYEgGiyeL3eoJ+/unXrlqZOnerX1t3dreXLlys/P19VVVWSpBdffFEDAwN6//33ZbX+L5NPnTqlH//4x6qsrNT3vvc9SdKJEye0YcMGVVdXq7CwUJLk9Xq1bt06dXZ26sMPP/Q9T7hjRuLWrZ6AHzmz2dLU3n4n4jGD6XWHPmGUygkjP2OxHxC5RNkPVqtFU6dOifxxoQoeDE5JevTRR/X444/r+vXrkqS2tjZ98sknKioq8oWcJC1evFhZWVmqr6/3tR09elRpaWlavny5r81isai4uFiXL19WS0tLxGMCQLQZnTC6ffu2mpubNWfOHEmSy+WSJN/3X+ZwONTc3Oz7vrm5WXa73S8QJcnpdPqNFcmYABBtEYen1+tVWVmZPB6P1q9fL0nq7OyUJKWnpw+rT09P9/UP1Qaq+/JYkYwJANEW8WJbeXm5jh07pt/97nd64okn/PosFsuIj3mwPVBdJLXBxggm1NqGzZZmNG4g3tt9SpuSErB/8uRJsn1t8qg+58NgtPcDzLAfAosoPCsrK3Xw4EFt27ZNa9as8bVnZGRI0ohHg11dXX5HjxkZGQHrpPtHmpGMGYlonzDqcw/oTs9/A/f3udU+ODiqzzneJcqJiniXKPthzE4YDfn973+vffv2afPmzfrhD3/o1ze0LjnSOqTL5fJbt7Tb7WptbR12qdHQGqfD4Yh4TACItrDCs7q6Wnv37tXGjRv1k5/8ZFj/tGnTNHfuXNXW1vqF4pkzZ9TW1qaVK1f62goLC9Xd3a3jx4/7jXHo0CHNmjVLdrs94jEBINpCvm0/ePCgqqqq9O1vf1vf+ta3/C6MnzhxonJyciRJmzZt0vr16/WLX/xC3//+99XW1qaKigrl5ubqueee8z1m2bJlWrhwobZt26bOzk5lZ2fr0KFDOnfunPbu3ev33OGOCQDRFvIi+R/84Af697//PWLf9OnT/Y4gT548qaqqKn322WdKTU3VihUrtHnz5mHrkz09PdqzZ4/q6+vV3d0tu92u1157TStWrBj2HOGOGS4uko9/ibLWFu8SZT+YrnmGDM+HDeEZ/xLlRRvvEmU/jPkJIwDAfYQnABggPAHAAOEJAAYITwAwQHgC
gAHCEwAMEJ4AYIDwBAADhCcAGCA8AcAA4QkABghPADBAeAKAAcITAAwQngBggPAEAAOEJwAYIDwBwADhCQAGCE8AMEB4AoABwhMADBCeAGCA8AQAA4QnABggPAHAAOEJAAYITwAwQHgCgAHCEwAMJMV6AuPZgEdy9w8ErfF4ozQZAFFFeH4F7v4BnW1sC1qT67BFaTYAoom37QBggPAEAAOEJwAYIDwBwADhCQAGCE8AMEB4AoABrvOMMYvVol538AvtJyUnKYlfc0BcITxjzN0/qIuu9qA185/MUtIkdhUQTzieAQADhCcAGCA8AcAA4QkABghPADBAeAKAAcITAAwQngBggPAEAAOEJwAYCCs8r1+/rp07d2rt2rV65pln5HQ61dDQMGJtbW2tnn/+eT399NNaunSpKioq5Ha7h9XdvHlTW7Zs0cKFC5WXl6d169bp/PnzX2nMh9XQ598DfQ14Yj1DIPGEFZ5ffPGF6urqNHnyZOXn5wesq6mp0aZNm/SNb3xDb731ll599VX98Y9/1NatW/3q3G63SkpKdPbsWZWVlam6ulqpqakqKSnRpUuXjMZ8mLn7B3W2sS3gV6i/4Alg9IV1t4n58+frzJkzkqRjx47p+PHjw2oGBwe1e/duFRQUaMeOHZKk/Px8JScnq6ysTCUlJcrNzZUk/fWvf1Vzc7Pef/99PfXUU5KkBQsW6Lvf/a727NmjAwcORDwmAERTWEeeVmvosgsXLqi9vV3FxcV+7atXr1ZycrLq6+t9bceOHZPD4fAFpyRNnDhRq1at0unTp9XT0xPxmAAQTaN2wqi5uVmSNGfOHL/2Rx55RDNmzPD1D9U6HI5hYzidTg0ODury5csRjwkA0TRq4dnZ2SlJSk9PH9aXnp7u6x+qDVQnSR0dHRGPCQDRNOp32LVYLGG1B6qLpDbYGIFMnTolaL/Nlhb2WN7bfUqbkhK0Jjk5KWhNqP5waiZPniTb1yYHn+w4E8l+wNhhPwQ2auGZkZEh6X9Hi5mZmX59XV1dys7O9qsd6aixq6vLb6xIxgzXrVs98ni8I/bZbGlqb78T9lh97gHd6flv0Jr+/uA1ofrDqenrc6t9cDD4ZMeRSPcDxkai7Aer1RLyoGrEx43WBOx2uyQNW4e8e/eurly54rduabfb5XK5ho3R1NSkCRMmaPbs2RGPCQDRNGrhmZeXJ5vNppqaGr/2Dz74QP39/Vq5cqWvrbCwUC6XS42Njb62e/fuqa6uTosWLdKUKVMiHhMAomnCjqELKEM4fPiwWlpadPHiRZ0/f17Z2dm6ffu2rl69qpkzZ8pqtSozM1P79+9XR0eHUlJSdPLkSZWXl6ugoECvvPKKbyyn06kjR46otrZWNptNN27c0K5du9TU1KSKigo99thjkhTRmOG6e/eevCO/a1dq6iT19d0Le6z+QY+u3ewNWjNtaqrabvUZ94dTk52VpnsDHvUPBv6yWq2yRr5EHBOR7geMjUTZDxaLRZMnT4z4cWGveW7cuNHv+6qqKknS9OnTfRfNFxcXy2q16sCBA3rvvfeUmZmpl19+Wa+//rrfYydNmqS3335b5eXl2rFjh9xut3JycnTw4EHNnTvXrzbcMRMZf4ETiL6wX01NTU1h1RUVFamoqChknc1m0+7du0d1TACIFu6qBAAGCE8AMEB4AoABwhMADBCeAGCA8AQAA4QnABggPAHAAOEJAAYITwAwQHgCgAHCEwAMEJ4AYIDwBAADhCcAGCA8AcAA4QkABghPADBAeAKAAcITAAwQngBggPAEAAOEJwAYIDwBwADhCQAGkmI9AUSHxWpRr3sgaM2k5CQl8esUCAvhmSDc/YO66GoPWjP/ySwlTeK/BBAOXilBDHgkd3/gozWPN4qTARBXCM8g3P0DOtvYFrA/12GL4mwAxBNWuADAAOEJAAYITwAwQHgCgAHCEwAMEJ4AYIDwBAADhCcAGCA8AcAA4QkABghPADDAZ9vhE+q2ddyy
DriP8IRPqNvWccs64D6OIwDAAOEJAAYITwAwQHgCgAHCEwAMcOoUYeMvcAL3EZ4IG3+BE7iPYwQAMEB4AoABwhMADIyLxane3l5VVlbq8OHD6u7ult1u12uvvably5fHemp4AJ+PR6IYF+FZWlqqS5cuadOmTcrOztbf/vY3lZaWat++fVq2bFmsp4cv4fPxSBRx/7/4xIkTOn36tKqrq1VYWChJys/P15UrV7Rr1y7Cc5wJ53KnlL57UZoNYC7uw/Po0aNKS0vze4tusVhUXFyssrIytbS0yG63x3CGiEQ4lzstzsvWvSABm5yUpP6B4AEcTg1LCPgq4j48m5ubZbfbZbX6/y93Op2SJJfLFVF4Wq2WsPuTJlg1OSU5YG2o/vE0RjzNddDj0X8+vx2w/8lZX1NjkP5wa3IdNg0OeAP2T0yaoAkhwnXQI90bGPxKY8SzUK+Xh4Hpzxj34dnZ2amZM2cOa09PT/f1RyIzMzVo/9SpU/y+z/6/9KD1s7MzQz5nqJp4GSNazxPOGDOyHh3zeSC0B18PuG9c/E60WAL/ZgjWBwBjJe7DMyMjY8Sjy66uLkn3j0ABIJriPjztdrtaW1vl8Xj82l0ulyTJ4XDEYloAElzch2dhYaG6u7t1/Phxv/ZDhw5p1qxZnGkHEBNxf8Jo2bJlWrhwobZt26bOzk5lZ2fr0KFDOnfunPbu3Rvr6QFIUBav1xv4Wo040dPToz179qi+vt7v45krVqyI9dQAJKhxEZ4AEG/ifs0TAOIR4QkABhI+PHt7e7Vz504tWbJE8+bN05o1a/TRRx/FeloPrYaGBjmdzhG/Wltb/WpPnTqll156SfPmzdOiRYu0fft2dXd3x2jm49f169e1c+dOrV27Vs8884ycTqcaGhpGrK2trdXzzz+vp59+WkuXLlVFRYXcbvewups3b2rLli1auHCh8vLytG7dOp0/f36sf5S4Evdn28cat7uLjU2bNmn+/Pl+bdnZ2b5/NzQ0aMOGDVq+fLl+/vOf68aNG6qoqJDL5dKf/vSnYfc6QGBffPGF6urqlJOTo/z8/GGX/Q2pqanRL3/5S61du1a//vWv1draqoqKCl29elWVlZW+OrfbrZKSEvX19amsrEwZGRl6++23VVJSonfffVc5OTnR+tFiy5vA/vGPf3gdDof3yJEjvjaPx+N9+eWXvc8991wMZ/bw+vjjj70Oh8N79OjRoHUvvPCCt6ioyDs4OOhr+9e//uV1OBzeurq6sZ7mQ+XL2/Do0aNeh8Ph/fjjj/1qBgYGvIsXL/b+9Kc/9Wv/y1/+4nU4HN4LFy742t555x2vw+Hwfvrpp742t9vtLSgo8K5fv36Mfor4k9C/voPd7u7y5ctqaWmJ4ewSV1tbmz755BMVFRX5HWEuXrxYWVlZqq+vj+Hsxp9wjtIvXLig9vZ2FRcX+7WvXr1aycnJftv82LFjcjgceuqpp3xtEydO1KpVq3T69Gn19PSM3uTjWEKHZzi3u8PY2L59u3JycvTNb35Tr776qj799FNf39B2nzNnzrDHORwONTc3R22eiWJomz64zR955BHNmDHDb5s3NzeP+LFop9OpwcFBXb58eWwnGycSes1ztG93h9DS0tL0ox/9SAsWLFBGRoZaW1u1f/9+rV27Vu+8845yc3N9232km76kp6fr0qVL0Z72Qy/UNv/ya6GzszNgnSR1dHSM0SzjS0KHp8Tt7qItJyfH74TCs88+q4KCAq1atUqVlZX6wx/+4OsLtP3ZL2Mn3G3O6ybB37Zzu7v4YLPZtGTJEl28eFHS//aLNPKRf1dXF/tlDESyzUO9bobGetgldHhyu7v48eV9MLTuNtLapsvlGnEtFF/N0N3JHtzmd+/e1ZUrV/y2ud1uH/F8QFNTkyZMmKDZs2eP7WTjREKHJ7e7iw/t7e06ffq08vLyJEnTpk3T3LlzVVtb6xeqZ86cUVtbm1auXBmrqT608vLyZLPZVFNT49f+wQcf
qL+/32+bFxYWyuVyqbGx0dd279491dXVadGiRZoyJTH+dMeEHTt27Ij1JGLl8ccf19mzZ/Xee+8pMzNT3d3dqq6u1t///nf99re/1axZs2I9xYfOG2+8ocbGRt25c0c3b97UP//5T/3qV7/SnTt3tHv3bmVlZUmSvv71r+vgwYNqaWlRenq6zp07p9/85jeaM2eOtm7dykXyETp8+LBaWlp08eJFnT9/XtnZ2bp9+7auXr2qmTNnymq1KjMzU/v371dHR4dSUlJ08uRJlZeXq6CgQK+88opvLKfTqSNHjqi2tlY2m003btzQrl271NTUpIqKCj322GMx/EmjJ+HvqsTt7qJr//79qqur09WrV3X37l1lZGRowYIF+tnPfjZsmeTkyZOqqqrSZ599ptTUVK1YsUKbN29mzdPA0OV3D5o+fbrfO6+amhodOHBAn3/+uTIzM7V69Wq9/vrrSklJ8Xtce3u7ysvLdeLECbndbuXk5OiNN97Qs88+O6Y/RzxJ+PAEABO89wEAA4QnABggPAHAAOEJAAYITwAwQHgCgAHCEwAMEJ4AYIDwBAAD/w/5l+xWCz/JlAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "task_name = 'toxcast'\n",
    "tasks = [\n",
    "'ACEA_T47D_80hr_Negative','ACEA_T47D_80hr_Positive','APR_HepG2_CellCycleArrest_24h_dn','APR_HepG2_CellCycleArrest_24h_up','APR_HepG2_CellCycleArrest_72h_dn','APR_HepG2_CellLoss_24h_dn','APR_HepG2_CellLoss_72h_dn','APR_HepG2_MicrotubuleCSK_24h_dn','APR_HepG2_MicrotubuleCSK_24h_up','APR_HepG2_MicrotubuleCSK_72h_dn','APR_HepG2_MicrotubuleCSK_72h_up','APR_HepG2_MitoMass_24h_dn','APR_HepG2_MitoMass_24h_up','APR_HepG2_MitoMass_72h_dn','APR_HepG2_MitoMass_72h_up','APR_HepG2_MitoMembPot_1h_dn','APR_HepG2_MitoMembPot_24h_dn','APR_HepG2_MitoMembPot_72h_dn','APR_HepG2_MitoticArrest_24h_up','APR_HepG2_MitoticArrest_72h_up','APR_HepG2_NuclearSize_24h_dn','APR_HepG2_NuclearSize_72h_dn','APR_HepG2_NuclearSize_72h_up','APR_HepG2_OxidativeStress_24h_up','APR_HepG2_OxidativeStress_72h_up','APR_HepG2_StressKinase_1h_up','APR_HepG2_StressKinase_24h_up','APR_HepG2_StressKinase_72h_up','APR_HepG2_p53Act_24h_up','APR_HepG2_p53Act_72h_up','APR_Hepat_Apoptosis_24hr_up','APR_Hepat_Apoptosis_48hr_up','APR_Hepat_CellLoss_24hr_dn','APR_Hepat_CellLoss_48hr_dn','APR_Hepat_DNADamage_24hr_up','APR_Hepat_DNADamage_48hr_up','APR_Hepat_DNATexture_24hr_up','APR_Hepat_DNATexture_48hr_up','APR_Hepat_MitoFxnI_1hr_dn','APR_Hepat_MitoFxnI_24hr_dn','APR_Hepat_MitoFxnI_48hr_dn','APR_Hepat_NuclearSize_24hr_dn','APR_Hepat_NuclearSize_48hr_dn','APR_Hepat_Steatosis_24hr_up','APR_Hepat_Steatosis_48hr_up','ATG_AP_1_CIS_dn','ATG_AP_1_CIS_up','ATG_AP_2_CIS_dn','ATG_AP_2_CIS_up','ATG_AR_TRANS_dn','ATG_AR_TRANS_up','ATG_Ahr_CIS_dn','ATG_Ahr_CIS_up','ATG_BRE_CIS_dn','ATG_BRE_CIS_up','ATG_CAR_TRANS_dn','ATG_CAR_TRANS_up','ATG_CMV_CIS_dn','ATG_CMV_CIS_up','ATG_CRE_CIS_dn','ATG_CRE_CIS_up','ATG_C_EBP_CIS_dn','ATG_C_EBP_CIS_up','ATG_DR4_LXR_CIS_dn','ATG_DR4_LXR_CIS_up','ATG_DR5_CIS_dn','ATG_DR5_CIS_up','ATG_E2F_CIS_dn','ATG_E2F_CIS_up','ATG_EGR_CIS_up','ATG_ERE_CIS_dn','ATG_ERE_CIS_up','ATG_ERRa_TRANS_dn','ATG_ERRg_TRANS_dn','ATG_ERRg_TRANS_up','ATG_ERa_TRANS_up','ATG_E_Box_CIS_dn','ATG_E_Box_CIS_up','ATG_Ets_CIS_dn','
ATG_Ets_CIS_up','ATG_FXR_TRANS_up','ATG_FoxA2_CIS_dn','ATG_FoxA2_CIS_up','ATG_FoxO_CIS_dn','ATG_FoxO_CIS_up','ATG_GAL4_TRANS_dn','ATG_GATA_CIS_dn','ATG_GATA_CIS_up','ATG_GLI_CIS_dn','ATG_GLI_CIS_up','ATG_GRE_CIS_dn','ATG_GRE_CIS_up','ATG_GR_TRANS_dn','ATG_GR_TRANS_up','ATG_HIF1a_CIS_dn','ATG_HIF1a_CIS_up','ATG_HNF4a_TRANS_dn','ATG_HNF4a_TRANS_up','ATG_HNF6_CIS_dn','ATG_HNF6_CIS_up','ATG_HSE_CIS_dn','ATG_HSE_CIS_up','ATG_IR1_CIS_dn','ATG_IR1_CIS_up','ATG_ISRE_CIS_dn','ATG_ISRE_CIS_up','ATG_LXRa_TRANS_dn','ATG_LXRa_TRANS_up','ATG_LXRb_TRANS_dn','ATG_LXRb_TRANS_up','ATG_MRE_CIS_up','ATG_M_06_TRANS_up','ATG_M_19_CIS_dn','ATG_M_19_TRANS_dn','ATG_M_19_TRANS_up','ATG_M_32_CIS_dn','ATG_M_32_CIS_up','ATG_M_32_TRANS_dn','ATG_M_32_TRANS_up','ATG_M_61_TRANS_up','ATG_Myb_CIS_dn','ATG_Myb_CIS_up','ATG_Myc_CIS_dn','ATG_Myc_CIS_up','ATG_NFI_CIS_dn','ATG_NFI_CIS_up','ATG_NF_kB_CIS_dn','ATG_NF_kB_CIS_up','ATG_NRF1_CIS_dn','ATG_NRF1_CIS_up','ATG_NRF2_ARE_CIS_dn','ATG_NRF2_ARE_CIS_up','ATG_NURR1_TRANS_dn','ATG_NURR1_TRANS_up','ATG_Oct_MLP_CIS_dn','ATG_Oct_MLP_CIS_up','ATG_PBREM_CIS_dn','ATG_PBREM_CIS_up','ATG_PPARa_TRANS_dn','ATG_PPARa_TRANS_up','ATG_PPARd_TRANS_up','ATG_PPARg_TRANS_up','ATG_PPRE_CIS_dn','ATG_PPRE_CIS_up','ATG_PXRE_CIS_dn','ATG_PXRE_CIS_up','ATG_PXR_TRANS_dn','ATG_PXR_TRANS_up','ATG_Pax6_CIS_up','ATG_RARa_TRANS_dn','ATG_RARa_TRANS_up','ATG_RARb_TRANS_dn','ATG_RARb_TRANS_up','ATG_RARg_TRANS_dn','ATG_RARg_TRANS_up','ATG_RORE_CIS_dn','ATG_RORE_CIS_up','ATG_RORb_TRANS_dn','ATG_RORg_TRANS_dn','ATG_RORg_TRANS_up','ATG_RXRa_TRANS_dn','ATG_RXRa_TRANS_up','ATG_RXRb_TRANS_dn','ATG_RXRb_TRANS_up','ATG_SREBP_CIS_dn','ATG_SREBP_CIS_up','ATG_STAT3_CIS_dn','ATG_STAT3_CIS_up','ATG_Sox_CIS_dn','ATG_Sox_CIS_up','ATG_Sp1_CIS_dn','ATG_Sp1_CIS_up','ATG_TAL_CIS_dn','ATG_TAL_CIS_up','ATG_TA_CIS_dn','ATG_TA_CIS_up','ATG_TCF_b_cat_CIS_dn','ATG_TCF_b_cat_CIS_up','ATG_TGFb_CIS_dn','ATG_TGFb_CIS_up','ATG_THRa1_TRANS_dn','ATG_THRa1_TRANS_up','ATG_VDRE_CIS_dn','ATG_VDRE_CIS_up','ATG_VDR_TRANS_dn','
ATG_VDR_TRANS_up','ATG_XTT_Cytotoxicity_up','ATG_Xbp1_CIS_dn','ATG_Xbp1_CIS_up','ATG_p53_CIS_dn','ATG_p53_CIS_up','BSK_3C_Eselectin_down','BSK_3C_HLADR_down','BSK_3C_ICAM1_down','BSK_3C_IL8_down','BSK_3C_MCP1_down','BSK_3C_MIG_down','BSK_3C_Proliferation_down','BSK_3C_SRB_down','BSK_3C_Thrombomodulin_down','BSK_3C_Thrombomodulin_up','BSK_3C_TissueFactor_down','BSK_3C_TissueFactor_up','BSK_3C_VCAM1_down','BSK_3C_Vis_down','BSK_3C_uPAR_down','BSK_4H_Eotaxin3_down','BSK_4H_MCP1_down','BSK_4H_Pselectin_down','BSK_4H_Pselectin_up','BSK_4H_SRB_down','BSK_4H_VCAM1_down','BSK_4H_VEGFRII_down','BSK_4H_uPAR_down','BSK_4H_uPAR_up','BSK_BE3C_HLADR_down','BSK_BE3C_IL1a_down','BSK_BE3C_IP10_down','BSK_BE3C_MIG_down','BSK_BE3C_MMP1_down','BSK_BE3C_MMP1_up','BSK_BE3C_PAI1_down','BSK_BE3C_SRB_down','BSK_BE3C_TGFb1_down','BSK_BE3C_tPA_down','BSK_BE3C_uPAR_down','BSK_BE3C_uPAR_up','BSK_BE3C_uPA_down','BSK_CASM3C_HLADR_down','BSK_CASM3C_IL6_down','BSK_CASM3C_IL6_up','BSK_CASM3C_IL8_down','BSK_CASM3C_LDLR_down','BSK_CASM3C_LDLR_up','BSK_CASM3C_MCP1_down','BSK_CASM3C_MCP1_up','BSK_CASM3C_MCSF_down','BSK_CASM3C_MCSF_up','BSK_CASM3C_MIG_down','BSK_CASM3C_Proliferation_down','BSK_CASM3C_Proliferation_up','BSK_CASM3C_SAA_down','BSK_CASM3C_SAA_up','BSK_CASM3C_SRB_down','BSK_CASM3C_Thrombomodulin_down','BSK_CASM3C_Thrombomodulin_up','BSK_CASM3C_TissueFactor_down','BSK_CASM3C_VCAM1_down','BSK_CASM3C_VCAM1_up','BSK_CASM3C_uPAR_down','BSK_CASM3C_uPAR_up','BSK_KF3CT_ICAM1_down','BSK_KF3CT_IL1a_down','BSK_KF3CT_IP10_down','BSK_KF3CT_IP10_up','BSK_KF3CT_MCP1_down','BSK_KF3CT_MCP1_up','BSK_KF3CT_MMP9_down','BSK_KF3CT_SRB_down','BSK_KF3CT_TGFb1_down','BSK_KF3CT_TIMP2_down','BSK_KF3CT_uPA_down','BSK_LPS_CD40_down','BSK_LPS_Eselectin_down','BSK_LPS_Eselectin_up','BSK_LPS_IL1a_down','BSK_LPS_IL1a_up','BSK_LPS_IL8_down','BSK_LPS_IL8_up','BSK_LPS_MCP1_down','BSK_LPS_MCSF_down','BSK_LPS_PGE2_down','BSK_LPS_PGE2_up','BSK_LPS_SRB_down','BSK_LPS_TNFa_down','BSK_LPS_TNFa_up','BSK_LPS_TissueFactor_down','BSK_LPS
_TissueFactor_up','BSK_LPS_VCAM1_down','BSK_SAg_CD38_down','BSK_SAg_CD40_down','BSK_SAg_CD69_down','BSK_SAg_Eselectin_down','BSK_SAg_Eselectin_up','BSK_SAg_IL8_down','BSK_SAg_IL8_up','BSK_SAg_MCP1_down','BSK_SAg_MIG_down','BSK_SAg_PBMCCytotoxicity_down','BSK_SAg_PBMCCytotoxicity_up','BSK_SAg_Proliferation_down','BSK_SAg_SRB_down','BSK_hDFCGF_CollagenIII_down','BSK_hDFCGF_EGFR_down','BSK_hDFCGF_EGFR_up','BSK_hDFCGF_IL8_down','BSK_hDFCGF_IP10_down','BSK_hDFCGF_MCSF_down','BSK_hDFCGF_MIG_down','BSK_hDFCGF_MMP1_down','BSK_hDFCGF_MMP1_up','BSK_hDFCGF_PAI1_down','BSK_hDFCGF_Proliferation_down','BSK_hDFCGF_SRB_down','BSK_hDFCGF_TIMP1_down','BSK_hDFCGF_VCAM1_down','CEETOX_H295R_11DCORT_dn','CEETOX_H295R_ANDR_dn','CEETOX_H295R_CORTISOL_dn','CEETOX_H295R_DOC_dn','CEETOX_H295R_DOC_up','CEETOX_H295R_ESTRADIOL_dn','CEETOX_H295R_ESTRADIOL_up','CEETOX_H295R_ESTRONE_dn','CEETOX_H295R_ESTRONE_up','CEETOX_H295R_OHPREG_up','CEETOX_H295R_OHPROG_dn','CEETOX_H295R_OHPROG_up','CEETOX_H295R_PROG_up','CEETOX_H295R_TESTO_dn','CLD_ABCB1_48hr','CLD_ABCG2_48hr','CLD_CYP1A1_24hr','CLD_CYP1A1_48hr','CLD_CYP1A1_6hr','CLD_CYP1A2_24hr','CLD_CYP1A2_48hr','CLD_CYP1A2_6hr','CLD_CYP2B6_24hr','CLD_CYP2B6_48hr','CLD_CYP2B6_6hr','CLD_CYP3A4_24hr','CLD_CYP3A4_48hr','CLD_CYP3A4_6hr','CLD_GSTA2_48hr','CLD_SULT2A_24hr','CLD_SULT2A_48hr','CLD_UGT1A1_24hr','CLD_UGT1A1_48hr','NCCT_HEK293T_CellTiterGLO','NCCT_QuantiLum_inhib_2_dn','NCCT_QuantiLum_inhib_dn','NCCT_TPO_AUR_dn','NCCT_TPO_GUA_dn','NHEERL_ZF_144hpf_TERATOSCORE_up','NVS_ADME_hCYP19A1','NVS_ADME_hCYP1A1','NVS_ADME_hCYP1A2','NVS_ADME_hCYP2A6','NVS_ADME_hCYP2B6','NVS_ADME_hCYP2C19','NVS_ADME_hCYP2C9','NVS_ADME_hCYP2D6','NVS_ADME_hCYP3A4','NVS_ADME_hCYP4F12','NVS_ADME_rCYP2C12','NVS_ENZ_hAChE','NVS_ENZ_hAMPKa1','NVS_ENZ_hAurA','NVS_ENZ_hBACE','NVS_ENZ_hCASP5','NVS_ENZ_hCK1D','NVS_ENZ_hDUSP3','NVS_ENZ_hES','NVS_ENZ_hElastase','NVS_ENZ_hFGFR1','NVS_ENZ_hGSK3b','NVS_ENZ_hMMP1','NVS_ENZ_hMMP13','NVS_ENZ_hMMP2','NVS_ENZ_hMMP3','NVS_ENZ_hMMP7','NVS_ENZ_hMMP9','NVS
_ENZ_hPDE10','NVS_ENZ_hPDE4A1','NVS_ENZ_hPDE5','NVS_ENZ_hPI3Ka','NVS_ENZ_hPTEN','NVS_ENZ_hPTPN11','NVS_ENZ_hPTPN12','NVS_ENZ_hPTPN13','NVS_ENZ_hPTPN9','NVS_ENZ_hPTPRC','NVS_ENZ_hSIRT1','NVS_ENZ_hSIRT2','NVS_ENZ_hTrkA','NVS_ENZ_hVEGFR2','NVS_ENZ_oCOX1','NVS_ENZ_oCOX2','NVS_ENZ_rAChE','NVS_ENZ_rCNOS','NVS_ENZ_rMAOAC','NVS_ENZ_rMAOAP','NVS_ENZ_rMAOBC','NVS_ENZ_rMAOBP','NVS_ENZ_rabI2C','NVS_GPCR_bAdoR_NonSelective','NVS_GPCR_bDR_NonSelective','NVS_GPCR_g5HT4','NVS_GPCR_gH2','NVS_GPCR_gLTB4','NVS_GPCR_gLTD4','NVS_GPCR_gMPeripheral_NonSelective','NVS_GPCR_gOpiateK','NVS_GPCR_h5HT2A','NVS_GPCR_h5HT5A','NVS_GPCR_h5HT6','NVS_GPCR_h5HT7','NVS_GPCR_hAT1','NVS_GPCR_hAdoRA1','NVS_GPCR_hAdoRA2a','NVS_GPCR_hAdra2A','NVS_GPCR_hAdra2C','NVS_GPCR_hAdrb1','NVS_GPCR_hAdrb2','NVS_GPCR_hAdrb3','NVS_GPCR_hDRD1','NVS_GPCR_hDRD2s','NVS_GPCR_hDRD4.4','NVS_GPCR_hH1','NVS_GPCR_hLTB4_BLT1','NVS_GPCR_hM1','NVS_GPCR_hM2','NVS_GPCR_hM3','NVS_GPCR_hM4','NVS_GPCR_hNK2','NVS_GPCR_hOpiate_D1','NVS_GPCR_hOpiate_mu','NVS_GPCR_hTXA2','NVS_GPCR_p5HT2C','NVS_GPCR_r5HT1_NonSelective','NVS_GPCR_r5HT_NonSelective','NVS_GPCR_rAdra1B','NVS_GPCR_rAdra1_NonSelective','NVS_GPCR_rAdra2_NonSelective','NVS_GPCR_rAdrb_NonSelective','NVS_GPCR_rNK1','NVS_GPCR_rNK3','NVS_GPCR_rOpiate_NonSelective','NVS_GPCR_rOpiate_NonSelectiveNa','NVS_GPCR_rSST','NVS_GPCR_rTRH','NVS_GPCR_rV1','NVS_GPCR_rabPAF','NVS_GPCR_rmAdra2B','NVS_IC_hKhERGCh','NVS_IC_rCaBTZCHL','NVS_IC_rCaDHPRCh_L','NVS_IC_rNaCh_site2','NVS_LGIC_bGABARa1','NVS_LGIC_h5HT3','NVS_LGIC_hNNR_NBungSens','NVS_LGIC_rGABAR_NonSelective','NVS_LGIC_rNNR_BungSens','NVS_MP_hPBR','NVS_MP_rPBR','NVS_NR_bER','NVS_NR_bPR','NVS_NR_cAR','NVS_NR_hAR','NVS_NR_hCAR_Antagonist','NVS_NR_hER','NVS_NR_hFXR_Agonist','NVS_NR_hFXR_Antagonist','NVS_NR_hGR','NVS_NR_hPPARa','NVS_NR_hPPARg','NVS_NR_hPR','NVS_NR_hPXR','NVS_NR_hRAR_Antagonist','NVS_NR_hRARa_Agonist','NVS_NR_hTRa_Antagonist','NVS_NR_mERa','NVS_NR_rAR','NVS_NR_rMR','NVS_OR_gSIGMA_NonSelective','NVS_TR_gDAT','NVS_TR_hAdoT','NVS_TR_hDAT
','NVS_TR_hNET','NVS_TR_hSERT','NVS_TR_rNET','NVS_TR_rSERT','NVS_TR_rVMAT2','OT_AR_ARELUC_AG_1440','OT_AR_ARSRC1_0480','OT_AR_ARSRC1_0960','OT_ER_ERaERa_0480','OT_ER_ERaERa_1440','OT_ER_ERaERb_0480','OT_ER_ERaERb_1440','OT_ER_ERbERb_0480','OT_ER_ERbERb_1440','OT_ERa_EREGFP_0120','OT_ERa_EREGFP_0480','OT_FXR_FXRSRC1_0480','OT_FXR_FXRSRC1_1440','OT_NURR1_NURR1RXRa_0480','OT_NURR1_NURR1RXRa_1440','TOX21_ARE_BLA_Agonist_ch1','TOX21_ARE_BLA_Agonist_ch2','TOX21_ARE_BLA_agonist_ratio','TOX21_ARE_BLA_agonist_viability','TOX21_AR_BLA_Agonist_ch1','TOX21_AR_BLA_Agonist_ch2','TOX21_AR_BLA_Agonist_ratio','TOX21_AR_BLA_Antagonist_ch1','TOX21_AR_BLA_Antagonist_ch2','TOX21_AR_BLA_Antagonist_ratio','TOX21_AR_BLA_Antagonist_viability','TOX21_AR_LUC_MDAKB2_Agonist','TOX21_AR_LUC_MDAKB2_Antagonist','TOX21_AR_LUC_MDAKB2_Antagonist2','TOX21_AhR_LUC_Agonist','TOX21_Aromatase_Inhibition','TOX21_AutoFluor_HEK293_Cell_blue','TOX21_AutoFluor_HEK293_Media_blue','TOX21_AutoFluor_HEPG2_Cell_blue','TOX21_AutoFluor_HEPG2_Cell_green','TOX21_AutoFluor_HEPG2_Media_blue','TOX21_AutoFluor_HEPG2_Media_green','TOX21_ELG1_LUC_Agonist','TOX21_ERa_BLA_Agonist_ch1','TOX21_ERa_BLA_Agonist_ch2','TOX21_ERa_BLA_Agonist_ratio','TOX21_ERa_BLA_Antagonist_ch1','TOX21_ERa_BLA_Antagonist_ch2','TOX21_ERa_BLA_Antagonist_ratio','TOX21_ERa_BLA_Antagonist_viability','TOX21_ERa_LUC_BG1_Agonist','TOX21_ERa_LUC_BG1_Antagonist','TOX21_ESRE_BLA_ch1','TOX21_ESRE_BLA_ch2','TOX21_ESRE_BLA_ratio','TOX21_ESRE_BLA_viability','TOX21_FXR_BLA_Antagonist_ch1','TOX21_FXR_BLA_Antagonist_ch2','TOX21_FXR_BLA_agonist_ch2','TOX21_FXR_BLA_agonist_ratio','TOX21_FXR_BLA_antagonist_ratio','TOX21_FXR_BLA_antagonist_viability','TOX21_GR_BLA_Agonist_ch1','TOX21_GR_BLA_Agonist_ch2','TOX21_GR_BLA_Agonist_ratio','TOX21_GR_BLA_Antagonist_ch2','TOX21_GR_BLA_Antagonist_ratio','TOX21_GR_BLA_Antagonist_viability','TOX21_HSE_BLA_agonist_ch1','TOX21_HSE_BLA_agonist_ch2','TOX21_HSE_BLA_agonist_ratio','TOX21_HSE_BLA_agonist_viability','TOX21_MMP_ratio_down','TO
X21_MMP_ratio_up','TOX21_MMP_viability','TOX21_NFkB_BLA_agonist_ch1','TOX21_NFkB_BLA_agonist_ch2','TOX21_NFkB_BLA_agonist_ratio','TOX21_NFkB_BLA_agonist_viability','TOX21_PPARd_BLA_Agonist_viability','TOX21_PPARd_BLA_Antagonist_ch1','TOX21_PPARd_BLA_agonist_ch1','TOX21_PPARd_BLA_agonist_ch2','TOX21_PPARd_BLA_agonist_ratio','TOX21_PPARd_BLA_antagonist_ratio','TOX21_PPARd_BLA_antagonist_viability','TOX21_PPARg_BLA_Agonist_ch1','TOX21_PPARg_BLA_Agonist_ch2','TOX21_PPARg_BLA_Agonist_ratio','TOX21_PPARg_BLA_Antagonist_ch1','TOX21_PPARg_BLA_antagonist_ratio','TOX21_PPARg_BLA_antagonist_viability','TOX21_TR_LUC_GH3_Agonist','TOX21_TR_LUC_GH3_Antagonist','TOX21_VDR_BLA_Agonist_viability','TOX21_VDR_BLA_Antagonist_ch1','TOX21_VDR_BLA_agonist_ch2','TOX21_VDR_BLA_agonist_ratio','TOX21_VDR_BLA_antagonist_ratio','TOX21_VDR_BLA_antagonist_viability','TOX21_p53_BLA_p1_ch1','TOX21_p53_BLA_p1_ch2','TOX21_p53_BLA_p1_ratio','TOX21_p53_BLA_p1_viability','TOX21_p53_BLA_p2_ch1','TOX21_p53_BLA_p2_ch2','TOX21_p53_BLA_p2_ratio','TOX21_p53_BLA_p2_viability','TOX21_p53_BLA_p3_ch1','TOX21_p53_BLA_p3_ch2','TOX21_p53_BLA_p3_ratio','TOX21_p53_BLA_p3_viability','TOX21_p53_BLA_p4_ch1','TOX21_p53_BLA_p4_ch2','TOX21_p53_BLA_p4_ratio','TOX21_p53_BLA_p4_viability','TOX21_p53_BLA_p5_ch1','TOX21_p53_BLA_p5_ch2','TOX21_p53_BLA_p5_ratio','TOX21_p53_BLA_p5_viability','Tanguay_ZF_120hpf_AXIS_up','Tanguay_ZF_120hpf_ActivityScore','Tanguay_ZF_120hpf_BRAI_up','Tanguay_ZF_120hpf_CFIN_up','Tanguay_ZF_120hpf_CIRC_up','Tanguay_ZF_120hpf_EYE_up','Tanguay_ZF_120hpf_JAW_up','Tanguay_ZF_120hpf_MORT_up','Tanguay_ZF_120hpf_OTIC_up','Tanguay_ZF_120hpf_PE_up','Tanguay_ZF_120hpf_PFIN_up','Tanguay_ZF_120hpf_PIG_up','Tanguay_ZF_120hpf_SNOU_up','Tanguay_ZF_120hpf_SOMI_up','Tanguay_ZF_120hpf_SWIM_up','Tanguay_ZF_120hpf_TRUN_up','Tanguay_ZF_120hpf_TR_up','Tanguay_ZF_120hpf_YSE_up'\n",
    "]\n",
    "raw_filename = \"../data/toxcast_data.csv\"\n",
    "feature_filename = raw_filename.replace('.csv','.pickle')\n",
    "filename = raw_filename.replace('.csv','')\n",
    "prefix_filename = raw_filename.split('/')[-1].replace('.csv','')\n",
    "smiles_tasks_df = pd.read_csv(raw_filename)\n",
    "smilesList = smiles_tasks_df.smiles.values\n",
    "print(\"number of all smiles: \",len(smilesList))\n",
    "\n",
    "# Parse each SMILES once. Results are appended only after parsing AND canonicalization\n",
    "# both succeed, so the three lists below stay index-aligned with the rows that are kept\n",
    "# in smiles_tasks_df (a partial append would silently shift labels onto wrong molecules).\n",
    "atom_num_dist = []\n",
    "remained_smiles = []\n",
    "canonical_smiles_list = []\n",
    "for smiles in smilesList:\n",
    "    try:\n",
    "        mol = Chem.MolFromSmiles(smiles)\n",
    "        # RDKit returns None (rather than raising) for SMILES it cannot parse/sanitize,\n",
    "        # e.g. the literal 'FAIL' entries in this CSV.\n",
    "        if mol is None:\n",
    "            raise ValueError(\"RDKit could not parse SMILES\")\n",
    "        n_atoms = mol.GetNumAtoms()\n",
    "        cano_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)\n",
    "    except Exception:\n",
    "        # Unparsable or non-string entries are reported and skipped; Exception (not a bare\n",
    "        # except) so that a KeyboardInterrupt still stops the loop.\n",
    "        print(\"not successfully processed smiles: \", smiles)\n",
    "        continue\n",
    "    atom_num_dist.append(n_atoms)\n",
    "    remained_smiles.append(smiles)\n",
    "    canonical_smiles_list.append(cano_smiles)\n",
    "print(\"number of successfully processed smiles: \", len(remained_smiles))\n",
    "\n",
    "# .copy() so the new 'cano_smiles' column is written to an independent frame rather than\n",
    "# a view of the unfiltered one (avoids pandas' SettingWithCopyWarning).\n",
    "smiles_tasks_df = smiles_tasks_df[smiles_tasks_df[\"smiles\"].isin(remained_smiles)].copy()\n",
    "assert len(smiles_tasks_df) == len(canonical_smiles_list), \"filtered rows and canonical SMILES are misaligned\"\n",
    "smiles_tasks_df['cano_smiles'] = canonical_smiles_list\n",
    "# Sanity check: a canonical SMILES must be a fixed point of canonicalization. Use positional\n",
    "# .iloc here -- after filtering, index labels no longer match list positions.\n",
    "if len(canonical_smiles_list) > 8:\n",
    "    assert canonical_smiles_list[8] == Chem.MolToSmiles(Chem.MolFromSmiles(smiles_tasks_df['cano_smiles'].iloc[8]), isomericSmiles=True)\n",
    "\n",
    "plt.figure(figsize=(5, 3))\n",
    "sns.set(font_scale=1.5)\n",
    "ax = sns.distplot(atom_num_dist, bins=28, kde=False)\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"atom_num_dist_\"+prefix_filename+\".png\",dpi=200)\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# print(len([i for i in atom_num_dist if i<51]),len([i for i in atom_num_dist if i>50]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- run bookkeeping ---\n",
    "# NOTE(review): torch was seeded with 8 in the import cell; confirm where random_seed is used.\n",
    "random_seed = 888\n",
    "# wall-clock timestamp with ':' -> '-' and ' ' -> '_' (safe to embed in file names)\n",
    "start_time = time.ctime().replace(':', '-').replace(' ', '_')\n",
    "start = time.time()\n",
    "\n",
    "# --- training loop ---\n",
    "batch_size = 100\n",
    "epochs = 800\n",
    "p_dropout = 0.5\n",
    "\n",
    "# --- model size ---\n",
    "fingerprint_dim = 200\n",
    "radius = 3\n",
    "T = 3\n",
    "\n",
    "# --- optimizer ---\n",
    "# NOTE(review): 3 and 3.5 are implausible as a raw L2 lambda / learning rate; they look like\n",
    "# negative log10 exponents (10**-weight_decay, 10**-learning_rate). Confirm where the optimizer is built.\n",
    "weight_decay = 3\n",
    "learning_rate = 3.5\n",
    "\n",
    "# --- output head: classification model with 2 classes per task ---\n",
    "per_task_output_units_num = 2\n",
    "output_units_num = per_task_output_units_num * len(tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Ca+2].[Cl-].[Cl-]\n",
      "[Cd+2].[Cl-].[Cl-]\n",
      "CC[n+]1ccccc1.F[P-](F)(F)(F)(F)F\n",
      "O.O.[Ba+2].[Cl-].[Cl-]\n",
      "CC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F\n",
      "N#C[Fe-2](C#N)(C#N)(C#N)(C#N)N=O.O.O.[Na+].[Na+]\n",
      "CCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCC[n+]1ccccc1.F[P-](F)(F)(F)(F)F\n",
      "CCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCn1cc[n+](C)c1C.F[P-](F)(F)(F)(F)F\n",
      "CCn1cc[n+](C)c1C.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F\n",
      "CCCC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCCCCCCCCCCCC[P+](CCCCCC)(CCCCCC)CCCCCC.F[P-](F)(F)(F)(F)F\n",
      "CCCC[n+]1ccc(C)cc1.F[P-](F)(F)(F)(F)F\n",
      "[Cl-].[Cl-].[Cu+2]\n",
      "[Cl-].[Cl-].[Hg+2]\n",
      "[Br-].[Na+]\n",
      "[I-].[K+]\n",
      "[Cl-].[Cl-].[Cl-].[Fe+3]\n",
      "[Cl-].[Cl-].[Fe+2]\n",
      "[Cl-].[Cl-].[SnH2+2]\n",
      "[Hg+2].[I-].[I-]\n",
      "CCCCCC[n+]1ccccc1.F[P-](F)(F)(F)(F)F\n",
      "CCCC[n+]1cccc(C)c1.F[P-](F)(F)(F)(F)F\n",
      "CCCC[N+]1(C)CCCCC1.F[P-](F)(F)(F)(F)F\n",
      "CCC[N+]1(C)CCCCC1.F[P-](F)(F)(F)(F)F\n",
      "CCCC[n+]1ccccc1C.F[P-](F)(F)(F)(F)F\n",
      "feature dicts file saved as ../data/toxcast_data.pickle\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>ACEA_T47D_80hr_Negative</th>\n",
       "      <th>ACEA_T47D_80hr_Positive</th>\n",
       "      <th>APR_HepG2_CellCycleArrest_24h_dn</th>\n",
       "      <th>APR_HepG2_CellCycleArrest_24h_up</th>\n",
       "      <th>APR_HepG2_CellCycleArrest_72h_dn</th>\n",
       "      <th>APR_HepG2_CellLoss_24h_dn</th>\n",
       "      <th>APR_HepG2_CellLoss_72h_dn</th>\n",
       "      <th>APR_HepG2_MicrotubuleCSK_24h_dn</th>\n",
       "      <th>APR_HepG2_MicrotubuleCSK_24h_up</th>\n",
       "      <th>...</th>\n",
       "      <th>Tanguay_ZF_120hpf_PE_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_PFIN_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_PIG_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_SNOU_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_SOMI_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_SWIM_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_TRUN_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_TR_up</th>\n",
       "      <th>Tanguay_ZF_120hpf_YSE_up</th>\n",
       "      <th>cano_smiles</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>[Cl-].[Cl-].[Ca++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Ca+2].[Cl-].[Cl-]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>93</th>\n",
       "      <td>[Cl-].[Cl-].[Cd++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Cd+2].[Cl-].[Cl-]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>213</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CC[N+]1=CC=CC=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CC[n+]1ccccc1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>220</th>\n",
       "      <td>O.O.[Cl-].[Cl-].[Ba++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>O.O.[Ba+2].[Cl-].[Cl-]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1203</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CC[N+]1(C)CCCC1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1777</th>\n",
       "      <td>O.O.[Na+].[Na+].O=N[Fe--](C#N)(C#N)(C#N)(C#N)C#N</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>N#C[Fe-2](C#N)(C#N)(C#N)(C#N)N=O.O.O.[Na+].[Na+]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2259</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2532</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2676</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCC[n+]1ccccc1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2983</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3020</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3021</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3022</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCCCCN1C=C[N+](C...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3023</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCCCCCCN1C=C[N+]...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3107</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCN1C=C[N+](C)=C1C</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCn1cc[n+](C)c1C.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3652</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCN1C=C[N+](C)=C1C</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCn1cc[n+](C)c1C.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3763</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3764</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3914</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCC[N+]1(C)CCCC1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3948</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCC[N+]1(C)CCCC1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4172</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCCCN1C=C[N+](C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4264</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCC[P+](CCCCCC)(...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCCCCCCCCCC[P+](CCCCCC)(CCCCCC)CCCCCC.F[P-]...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4408</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=C(C)C=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCC[n+]1ccc(C)cc1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7018</th>\n",
       "      <td>[Cl-].[Cl-].[Cu++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Cl-].[Cl-].[Cu+2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7034</th>\n",
       "      <td>[Cl-].[Cl-].[Hg++]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>[Cl-].[Cl-].[Hg+2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7162</th>\n",
       "      <td>[Na+].[Br-]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Br-].[Na+]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7188</th>\n",
       "      <td>[K+].[I-]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[I-].[K+]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7211</th>\n",
       "      <td>[Cl-].[Cl-].[Cl-].[Fe+3]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Cl-].[Cl-].[Cl-].[Fe+3]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7255</th>\n",
       "      <td>[Cl-].[Cl-].[Fe++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Cl-].[Cl-].[Fe+2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7262</th>\n",
       "      <td>[Cl-].[Cl-].[SnH2++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Cl-].[Cl-].[SnH2+2]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7266</th>\n",
       "      <td>[I-].[I-].[Hg++]</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>[Hg+2].[I-].[I-]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7435</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCCCC[N+]1=CC=CC=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCCCC[n+]1ccccc1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7696</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC(C)=C1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCC[n+]1cccc(C)c1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8589</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCC[N+]1(C)CCCCC1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCC[N+]1(C)CCCCC1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8590</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCC[N+]1(C)CCCCC1</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCC[N+]1(C)CCCCC1.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8593</th>\n",
       "      <td>F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC=C1C</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>CCCC[n+]1ccccc1C.F[P-](F)(F)(F)(F)F</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>36 rows × 619 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 smiles  \\\n",
       "47                                   [Cl-].[Cl-].[Ca++]   \n",
       "93                                   [Cl-].[Cl-].[Cd++]   \n",
       "213                 F[P-](F)(F)(F)(F)F.CC[N+]1=CC=CC=C1   \n",
       "220                              O.O.[Cl-].[Cl-].[Ba++]   \n",
       "1203                 F[P-](F)(F)(F)(F)F.CC[N+]1(C)CCCC1   \n",
       "1777   O.O.[Na+].[Na+].O=N[Fe--](C#N)(C#N)(C#N)(C#N)C#N   \n",
       "2259               F[P-](F)(F)(F)(F)F.CCN1C=C[N+](C)=C1   \n",
       "2532             F[P-](F)(F)(F)(F)F.CCCCN1C=C[N+](C)=C1   \n",
       "2676              F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC=C1   \n",
       "2983              F[P-](F)(F)(F)(F)F.CCCN1C=C[N+](C)=C1   \n",
       "3020     F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCN1C=C[N+](C)=C1   \n",
       "3021   F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCCN1C=C[N+](C)=C1   \n",
       "3022  F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCCCCN1C=C[N+](C...   \n",
       "3023  F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCCCCCCN1C=C[N+]...   \n",
       "3107            F[P-](F)(F)(F)(F)F.CCCCN1C=C[N+](C)=C1C   \n",
       "3652              F[P-](F)(F)(F)(F)F.CCN1C=C[N+](C)=C1C   \n",
       "3763           F[P-](F)(F)(F)(F)F.CCCCCCN1C=C[N+](C)=C1   \n",
       "3764         F[P-](F)(F)(F)(F)F.CCCCCCCCN1C=C[N+](C)=C1   \n",
       "3914                F[P-](F)(F)(F)(F)F.CCC[N+]1(C)CCCC1   \n",
       "3948               F[P-](F)(F)(F)(F)F.CCCC[N+]1(C)CCCC1   \n",
       "4172       F[P-](F)(F)(F)(F)F.CCCCCCCCCCN1C=C[N+](C)=C1   \n",
       "4264  F[P-](F)(F)(F)(F)F.CCCCCCCCCCCCCC[P+](CCCCCC)(...   \n",
       "4408           F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=C(C)C=C1   \n",
       "7018                                 [Cl-].[Cl-].[Cu++]   \n",
       "7034                                 [Cl-].[Cl-].[Hg++]   \n",
       "7162                                        [Na+].[Br-]   \n",
       "7188                                          [K+].[I-]   \n",
       "7211                           [Cl-].[Cl-].[Cl-].[Fe+3]   \n",
       "7255                                 [Cl-].[Cl-].[Fe++]   \n",
       "7262                               [Cl-].[Cl-].[SnH2++]   \n",
       "7266                                   [I-].[I-].[Hg++]   \n",
       "7435            F[P-](F)(F)(F)(F)F.CCCCCC[N+]1=CC=CC=C1   \n",
       "7696           F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC(C)=C1   \n",
       "8589              F[P-](F)(F)(F)(F)F.CCCC[N+]1(C)CCCCC1   \n",
       "8590               F[P-](F)(F)(F)(F)F.CCC[N+]1(C)CCCCC1   \n",
       "8593             F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC=C1C   \n",
       "\n",
       "      ACEA_T47D_80hr_Negative  ACEA_T47D_80hr_Positive  \\\n",
       "47                        NaN                      NaN   \n",
       "93                        NaN                      NaN   \n",
       "213                       NaN                      NaN   \n",
       "220                       NaN                      NaN   \n",
       "1203                      NaN                      NaN   \n",
       "1777                      NaN                      NaN   \n",
       "2259                      NaN                      NaN   \n",
       "2532                      NaN                      NaN   \n",
       "2676                      NaN                      NaN   \n",
       "2983                      NaN                      NaN   \n",
       "3020                      NaN                      NaN   \n",
       "3021                      NaN                      NaN   \n",
       "3022                      NaN                      NaN   \n",
       "3023                      NaN                      NaN   \n",
       "3107                      NaN                      NaN   \n",
       "3652                      NaN                      NaN   \n",
       "3763                      NaN                      NaN   \n",
       "3764                      NaN                      NaN   \n",
       "3914                      NaN                      NaN   \n",
       "3948                      NaN                      NaN   \n",
       "4172                      NaN                      NaN   \n",
       "4264                      NaN                      NaN   \n",
       "4408                      NaN                      NaN   \n",
       "7018                      NaN                      NaN   \n",
       "7034                      1.0                      0.0   \n",
       "7162                      NaN                      NaN   \n",
       "7188                      NaN                      NaN   \n",
       "7211                      NaN                      NaN   \n",
       "7255                      NaN                      NaN   \n",
       "7262                      NaN                      NaN   \n",
       "7266                      NaN                      NaN   \n",
       "7435                      NaN                      NaN   \n",
       "7696                      NaN                      NaN   \n",
       "8589                      NaN                      NaN   \n",
       "8590                      NaN                      NaN   \n",
       "8593                      NaN                      NaN   \n",
       "\n",
       "      APR_HepG2_CellCycleArrest_24h_dn  APR_HepG2_CellCycleArrest_24h_up  \\\n",
       "47                                 NaN                               NaN   \n",
       "93                                 NaN                               NaN   \n",
       "213                                NaN                               NaN   \n",
       "220                                NaN                               NaN   \n",
       "1203                               NaN                               NaN   \n",
       "1777                               NaN                               NaN   \n",
       "2259                               NaN                               NaN   \n",
       "2532                               NaN                               NaN   \n",
       "2676                               NaN                               NaN   \n",
       "2983                               NaN                               NaN   \n",
       "3020                               NaN                               NaN   \n",
       "3021                               NaN                               NaN   \n",
       "3022                               NaN                               NaN   \n",
       "3023                               NaN                               NaN   \n",
       "3107                               NaN                               NaN   \n",
       "3652                               NaN                               NaN   \n",
       "3763                               NaN                               NaN   \n",
       "3764                               NaN                               NaN   \n",
       "3914                               NaN                               NaN   \n",
       "3948                               NaN                               NaN   \n",
       "4172                               NaN                               NaN   \n",
       "4264                               NaN                               NaN   \n",
       "4408                               NaN                               NaN   \n",
       "7018                               NaN                               NaN   \n",
       "7034                               0.0                               0.0   \n",
       "7162                               NaN                               NaN   \n",
       "7188                               NaN                               NaN   \n",
       "7211                               NaN                               NaN   \n",
       "7255                               NaN                               NaN   \n",
       "7262                               NaN                               NaN   \n",
       "7266                               NaN                               NaN   \n",
       "7435                               NaN                               NaN   \n",
       "7696                               NaN                               NaN   \n",
       "8589                               NaN                               NaN   \n",
       "8590                               NaN                               NaN   \n",
       "8593                               NaN                               NaN   \n",
       "\n",
       "      APR_HepG2_CellCycleArrest_72h_dn  APR_HepG2_CellLoss_24h_dn  \\\n",
       "47                                 NaN                        NaN   \n",
       "93                                 NaN                        NaN   \n",
       "213                                NaN                        NaN   \n",
       "220                                NaN                        NaN   \n",
       "1203                               NaN                        NaN   \n",
       "1777                               NaN                        NaN   \n",
       "2259                               NaN                        NaN   \n",
       "2532                               NaN                        NaN   \n",
       "2676                               NaN                        NaN   \n",
       "2983                               NaN                        NaN   \n",
       "3020                               NaN                        NaN   \n",
       "3021                               NaN                        NaN   \n",
       "3022                               NaN                        NaN   \n",
       "3023                               NaN                        NaN   \n",
       "3107                               NaN                        NaN   \n",
       "3652                               NaN                        NaN   \n",
       "3763                               NaN                        NaN   \n",
       "3764                               NaN                        NaN   \n",
       "3914                               NaN                        NaN   \n",
       "3948                               NaN                        NaN   \n",
       "4172                               NaN                        NaN   \n",
       "4264                               NaN                        NaN   \n",
       "4408                               NaN                        NaN   \n",
       "7018                               NaN                        NaN   \n",
       "7034                               0.0                        1.0   \n",
       "7162                               NaN                        NaN   \n",
       "7188                               NaN                        NaN   \n",
       "7211                               NaN                        NaN   \n",
       "7255                               NaN                        NaN   \n",
       "7262                               NaN                        NaN   \n",
       "7266                               NaN                        NaN   \n",
       "7435                               NaN                        NaN   \n",
       "7696                               NaN                        NaN   \n",
       "8589                               NaN                        NaN   \n",
       "8590                               NaN                        NaN   \n",
       "8593                               NaN                        NaN   \n",
       "\n",
       "      APR_HepG2_CellLoss_72h_dn  APR_HepG2_MicrotubuleCSK_24h_dn  \\\n",
       "47                          NaN                              NaN   \n",
       "93                          NaN                              NaN   \n",
       "213                         NaN                              NaN   \n",
       "220                         NaN                              NaN   \n",
       "1203                        NaN                              NaN   \n",
       "1777                        NaN                              NaN   \n",
       "2259                        NaN                              NaN   \n",
       "2532                        NaN                              NaN   \n",
       "2676                        NaN                              NaN   \n",
       "2983                        NaN                              NaN   \n",
       "3020                        NaN                              NaN   \n",
       "3021                        NaN                              NaN   \n",
       "3022                        NaN                              NaN   \n",
       "3023                        NaN                              NaN   \n",
       "3107                        NaN                              NaN   \n",
       "3652                        NaN                              NaN   \n",
       "3763                        NaN                              NaN   \n",
       "3764                        NaN                              NaN   \n",
       "3914                        NaN                              NaN   \n",
       "3948                        NaN                              NaN   \n",
       "4172                        NaN                              NaN   \n",
       "4264                        NaN                              NaN   \n",
       "4408                        NaN                              NaN   \n",
       "7018                        NaN                              NaN   \n",
       "7034                        1.0                              1.0   \n",
       "7162                        NaN                              NaN   \n",
       "7188                        NaN                              NaN   \n",
       "7211                        NaN                              NaN   \n",
       "7255                        NaN                              NaN   \n",
       "7262                        NaN                              NaN   \n",
       "7266                        NaN                              NaN   \n",
       "7435                        NaN                              NaN   \n",
       "7696                        NaN                              NaN   \n",
       "8589                        NaN                              NaN   \n",
       "8590                        NaN                              NaN   \n",
       "8593                        NaN                              NaN   \n",
       "\n",
       "      APR_HepG2_MicrotubuleCSK_24h_up  ...  Tanguay_ZF_120hpf_PE_up  \\\n",
       "47                                NaN  ...                      NaN   \n",
       "93                                NaN  ...                      NaN   \n",
       "213                               NaN  ...                      NaN   \n",
       "220                               NaN  ...                      NaN   \n",
       "1203                              NaN  ...                      NaN   \n",
       "1777                              NaN  ...                      NaN   \n",
       "2259                              NaN  ...                      NaN   \n",
       "2532                              NaN  ...                      NaN   \n",
       "2676                              NaN  ...                      NaN   \n",
       "2983                              NaN  ...                      NaN   \n",
       "3020                              NaN  ...                      NaN   \n",
       "3021                              NaN  ...                      NaN   \n",
       "3022                              NaN  ...                      NaN   \n",
       "3023                              NaN  ...                      NaN   \n",
       "3107                              NaN  ...                      NaN   \n",
       "3652                              NaN  ...                      NaN   \n",
       "3763                              NaN  ...                      NaN   \n",
       "3764                              NaN  ...                      NaN   \n",
       "3914                              NaN  ...                      NaN   \n",
       "3948                              NaN  ...                      NaN   \n",
       "4172                              NaN  ...                      NaN   \n",
       "4264                              NaN  ...                      NaN   \n",
       "4408                              NaN  ...                      NaN   \n",
       "7018                              NaN  ...                      NaN   \n",
       "7034                              0.0  ...                      1.0   \n",
       "7162                              NaN  ...                      NaN   \n",
       "7188                              NaN  ...                      NaN   \n",
       "7211                              NaN  ...                      NaN   \n",
       "7255                              NaN  ...                      NaN   \n",
       "7262                              NaN  ...                      NaN   \n",
       "7266                              NaN  ...                      NaN   \n",
       "7435                              NaN  ...                      NaN   \n",
       "7696                              NaN  ...                      NaN   \n",
       "8589                              NaN  ...                      NaN   \n",
       "8590                              NaN  ...                      NaN   \n",
       "8593                              NaN  ...                      NaN   \n",
       "\n",
       "      Tanguay_ZF_120hpf_PFIN_up  Tanguay_ZF_120hpf_PIG_up  \\\n",
       "47                          NaN                       NaN   \n",
       "93                          NaN                       NaN   \n",
       "213                         NaN                       NaN   \n",
       "220                         NaN                       NaN   \n",
       "1203                        NaN                       NaN   \n",
       "1777                        NaN                       NaN   \n",
       "2259                        NaN                       NaN   \n",
       "2532                        NaN                       NaN   \n",
       "2676                        NaN                       NaN   \n",
       "2983                        NaN                       NaN   \n",
       "3020                        NaN                       NaN   \n",
       "3021                        NaN                       NaN   \n",
       "3022                        NaN                       NaN   \n",
       "3023                        NaN                       NaN   \n",
       "3107                        NaN                       NaN   \n",
       "3652                        NaN                       NaN   \n",
       "3763                        NaN                       NaN   \n",
       "3764                        NaN                       NaN   \n",
       "3914                        NaN                       NaN   \n",
       "3948                        NaN                       NaN   \n",
       "4172                        NaN                       NaN   \n",
       "4264                        NaN                       NaN   \n",
       "4408                        NaN                       NaN   \n",
       "7018                        NaN                       NaN   \n",
       "7034                        1.0                       1.0   \n",
       "7162                        NaN                       NaN   \n",
       "7188                        NaN                       NaN   \n",
       "7211                        NaN                       NaN   \n",
       "7255                        NaN                       NaN   \n",
       "7262                        NaN                       NaN   \n",
       "7266                        NaN                       NaN   \n",
       "7435                        NaN                       NaN   \n",
       "7696                        NaN                       NaN   \n",
       "8589                        NaN                       NaN   \n",
       "8590                        NaN                       NaN   \n",
       "8593                        NaN                       NaN   \n",
       "\n",
       "      Tanguay_ZF_120hpf_SNOU_up  Tanguay_ZF_120hpf_SOMI_up  \\\n",
       "47                          NaN                        NaN   \n",
       "93                          NaN                        NaN   \n",
       "213                         NaN                        NaN   \n",
       "220                         NaN                        NaN   \n",
       "1203                        NaN                        NaN   \n",
       "1777                        NaN                        NaN   \n",
       "2259                        NaN                        NaN   \n",
       "2532                        NaN                        NaN   \n",
       "2676                        NaN                        NaN   \n",
       "2983                        NaN                        NaN   \n",
       "3020                        NaN                        NaN   \n",
       "3021                        NaN                        NaN   \n",
       "3022                        NaN                        NaN   \n",
       "3023                        NaN                        NaN   \n",
       "3107                        NaN                        NaN   \n",
       "3652                        NaN                        NaN   \n",
       "3763                        NaN                        NaN   \n",
       "3764                        NaN                        NaN   \n",
       "3914                        NaN                        NaN   \n",
       "3948                        NaN                        NaN   \n",
       "4172                        NaN                        NaN   \n",
       "4264                        NaN                        NaN   \n",
       "4408                        NaN                        NaN   \n",
       "7018                        NaN                        NaN   \n",
       "7034                        1.0                        1.0   \n",
       "7162                        NaN                        NaN   \n",
       "7188                        NaN                        NaN   \n",
       "7211                        NaN                        NaN   \n",
       "7255                        NaN                        NaN   \n",
       "7262                        NaN                        NaN   \n",
       "7266                        NaN                        NaN   \n",
       "7435                        NaN                        NaN   \n",
       "7696                        NaN                        NaN   \n",
       "8589                        NaN                        NaN   \n",
       "8590                        NaN                        NaN   \n",
       "8593                        NaN                        NaN   \n",
       "\n",
       "      Tanguay_ZF_120hpf_SWIM_up  Tanguay_ZF_120hpf_TRUN_up  \\\n",
       "47                          NaN                        NaN   \n",
       "93                          NaN                        NaN   \n",
       "213                         NaN                        NaN   \n",
       "220                         NaN                        NaN   \n",
       "1203                        NaN                        NaN   \n",
       "1777                        NaN                        NaN   \n",
       "2259                        NaN                        NaN   \n",
       "2532                        NaN                        NaN   \n",
       "2676                        NaN                        NaN   \n",
       "2983                        NaN                        NaN   \n",
       "3020                        NaN                        NaN   \n",
       "3021                        NaN                        NaN   \n",
       "3022                        NaN                        NaN   \n",
       "3023                        NaN                        NaN   \n",
       "3107                        NaN                        NaN   \n",
       "3652                        NaN                        NaN   \n",
       "3763                        NaN                        NaN   \n",
       "3764                        NaN                        NaN   \n",
       "3914                        NaN                        NaN   \n",
       "3948                        NaN                        NaN   \n",
       "4172                        NaN                        NaN   \n",
       "4264                        NaN                        NaN   \n",
       "4408                        NaN                        NaN   \n",
       "7018                        NaN                        NaN   \n",
       "7034                        1.0                        1.0   \n",
       "7162                        NaN                        NaN   \n",
       "7188                        NaN                        NaN   \n",
       "7211                        NaN                        NaN   \n",
       "7255                        NaN                        NaN   \n",
       "7262                        NaN                        NaN   \n",
       "7266                        NaN                        NaN   \n",
       "7435                        NaN                        NaN   \n",
       "7696                        NaN                        NaN   \n",
       "8589                        NaN                        NaN   \n",
       "8590                        NaN                        NaN   \n",
       "8593                        NaN                        NaN   \n",
       "\n",
       "      Tanguay_ZF_120hpf_TR_up  Tanguay_ZF_120hpf_YSE_up  \\\n",
       "47                        NaN                       NaN   \n",
       "93                        NaN                       NaN   \n",
       "213                       NaN                       NaN   \n",
       "220                       NaN                       NaN   \n",
       "1203                      NaN                       NaN   \n",
       "1777                      NaN                       NaN   \n",
       "2259                      NaN                       NaN   \n",
       "2532                      NaN                       NaN   \n",
       "2676                      NaN                       NaN   \n",
       "2983                      NaN                       NaN   \n",
       "3020                      NaN                       NaN   \n",
       "3021                      NaN                       NaN   \n",
       "3022                      NaN                       NaN   \n",
       "3023                      NaN                       NaN   \n",
       "3107                      NaN                       NaN   \n",
       "3652                      NaN                       NaN   \n",
       "3763                      NaN                       NaN   \n",
       "3764                      NaN                       NaN   \n",
       "3914                      NaN                       NaN   \n",
       "3948                      NaN                       NaN   \n",
       "4172                      NaN                       NaN   \n",
       "4264                      NaN                       NaN   \n",
       "4408                      NaN                       NaN   \n",
       "7018                      NaN                       NaN   \n",
       "7034                      1.0                       1.0   \n",
       "7162                      NaN                       NaN   \n",
       "7188                      NaN                       NaN   \n",
       "7211                      NaN                       NaN   \n",
       "7255                      NaN                       NaN   \n",
       "7262                      NaN                       NaN   \n",
       "7266                      NaN                       NaN   \n",
       "7435                      NaN                       NaN   \n",
       "7696                      NaN                       NaN   \n",
       "8589                      NaN                       NaN   \n",
       "8590                      NaN                       NaN   \n",
       "8593                      NaN                       NaN   \n",
       "\n",
       "                                            cano_smiles  \n",
       "47                                   [Ca+2].[Cl-].[Cl-]  \n",
       "93                                   [Cd+2].[Cl-].[Cl-]  \n",
       "213                    CC[n+]1ccccc1.F[P-](F)(F)(F)(F)F  \n",
       "220                              O.O.[Ba+2].[Cl-].[Cl-]  \n",
       "1203                 CC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F  \n",
       "1777   N#C[Fe-2](C#N)(C#N)(C#N)(C#N)N=O.O.O.[Na+].[Na+]  \n",
       "2259                 CCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "2532               CCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "2676                 CCCC[n+]1ccccc1.F[P-](F)(F)(F)(F)F  \n",
       "2983                CCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "3020       CCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "3021     CCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "3022   CCCCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "3023  CCCCCCCCCCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)...  \n",
       "3107              CCCCn1cc[n+](C)c1C.F[P-](F)(F)(F)(F)F  \n",
       "3652                CCn1cc[n+](C)c1C.F[P-](F)(F)(F)(F)F  \n",
       "3763             CCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "3764           CCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "3914                CCC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F  \n",
       "3948               CCCC[N+]1(C)CCCC1.F[P-](F)(F)(F)(F)F  \n",
       "4172         CCCCCCCCCCn1cc[n+](C)c1.F[P-](F)(F)(F)(F)F  \n",
       "4264  CCCCCCCCCCCCCC[P+](CCCCCC)(CCCCCC)CCCCCC.F[P-]...  \n",
       "4408              CCCC[n+]1ccc(C)cc1.F[P-](F)(F)(F)(F)F  \n",
       "7018                                 [Cl-].[Cl-].[Cu+2]  \n",
       "7034                                 [Cl-].[Cl-].[Hg+2]  \n",
       "7162                                        [Br-].[Na+]  \n",
       "7188                                          [I-].[K+]  \n",
       "7211                           [Cl-].[Cl-].[Cl-].[Fe+3]  \n",
       "7255                                 [Cl-].[Cl-].[Fe+2]  \n",
       "7262                               [Cl-].[Cl-].[SnH2+2]  \n",
       "7266                                   [Hg+2].[I-].[I-]  \n",
       "7435               CCCCCC[n+]1ccccc1.F[P-](F)(F)(F)(F)F  \n",
       "7696              CCCC[n+]1cccc(C)c1.F[P-](F)(F)(F)(F)F  \n",
       "8589              CCCC[N+]1(C)CCCCC1.F[P-](F)(F)(F)(F)F  \n",
       "8590               CCC[N+]1(C)CCCCC1.F[P-](F)(F)(F)(F)F  \n",
       "8593                CCCC[n+]1ccccc1C.F[P-](F)(F)(F)(F)F  \n",
       "\n",
       "[36 rows x 619 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Drop molecules with more than MAX_ATOMS atoms; each SMILES is parsed by RDKit only once.\n",
    "MAX_ATOMS = 150\n",
    "atom_counts = {}\n",
    "for smiles in canonical_smiles_list:\n",
    "    mol = Chem.MolFromSmiles(smiles)\n",
    "    if mol is None:\n",
    "        raise ValueError(f\"RDKit could not parse SMILES: {smiles!r}\")\n",
    "    atom_counts[smiles] = mol.GetNumAtoms()\n",
    "\n",
    "smilesList = [smiles for smiles in canonical_smiles_list if atom_counts[smiles] <= MAX_ATOMS]\n",
    "uncovered = [smiles for smiles in canonical_smiles_list if atom_counts[smiles] > MAX_ATOMS]\n",
    "\n",
    "smiles_tasks_df = smiles_tasks_df[~smiles_tasks_df[\"cano_smiles\"].isin(uncovered)]\n",
    "\n",
    "# Reuse the feature pickle when it exists, otherwise build the features with save_smiles_dicts.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load cache files you generated yourself.\n",
    "if os.path.isfile(feature_filename):\n",
    "    with open(feature_filename, \"rb\") as fh:\n",
    "        feature_dicts = pickle.load(fh)\n",
    "else:\n",
    "    feature_dicts = save_smiles_dicts(smilesList, filename)\n",
    "\n",
    "# Rows whose SMILES have no entry in the feature dicts are collected in uncovered_df (displayed below).\n",
    "remained_df = smiles_tasks_df[smiles_tasks_df[\"cano_smiles\"].isin(feature_dicts['smiles_to_atom_mask'].keys())]\n",
    "uncovered_df = smiles_tasks_df.drop(remained_df.index)\n",
    "uncovered_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_train_valid_test(df, seed=68):\n",
    "    \"\"\"Randomly split df into (train, valid, test) of roughly 80% / 10% / 10%.\n",
    "\n",
    "    The test set is a 10% sample of df; the validation set is 1/9 of the remainder\n",
    "    (10% overall); the rest is train. An empty frame is returned as all three splits.\n",
    "    \"\"\"\n",
    "    if len(df) == 0:\n",
    "        return df, df, df\n",
    "    test = df.sample(frac=0.1, random_state=seed)\n",
    "    rest = df.drop(test.index)\n",
    "    valid = rest.sample(frac=1/9, random_state=seed)\n",
    "    train = rest.drop(valid.index)\n",
    "    return train, valid, test\n",
    "\n",
    "\n",
    "# For every task, split negatives and positives separately (so each split keeps the class ratio),\n",
    "# then outer-merge the per-task splits on cano_smiles into one wide frame per split.\n",
    "weights = []\n",
    "positive_dist = []\n",
    "negative_dist = []\n",
    "for i, task in enumerate(tasks):\n",
    "    negative_df = remained_df[remained_df[task] == 0][[\"cano_smiles\", task]]\n",
    "    positive_df = remained_df[remained_df[task] == 1][[\"cano_smiles\", task]]\n",
    "    n_neg, n_pos = len(negative_df), len(positive_df)\n",
    "\n",
    "    # Inverse-frequency class weights [negative, positive]; fall back to [1, 1]\n",
    "    # when a task has no examples of one class (the ratio would divide by zero).\n",
    "    if n_neg and n_pos:\n",
    "        weights.append([(n_pos + n_neg) / n_neg, (n_pos + n_neg) / n_pos])\n",
    "    else:\n",
    "        weights.append([1, 1])\n",
    "    negative_dist.append(n_neg)\n",
    "    positive_dist.append(n_pos)\n",
    "\n",
    "    negative_train, negative_valid, negative_test = split_train_valid_test(negative_df)\n",
    "    positive_train, positive_valid, positive_test = split_train_valid_test(positive_df)\n",
    "\n",
    "    train_df_new = pd.concat([negative_train, positive_train])\n",
    "    valid_df_new = pd.concat([negative_valid, positive_valid])\n",
    "    test_df_new = pd.concat([negative_test, positive_test])\n",
    "    if i == 0:\n",
    "        train_df, valid_df, test_df = train_df_new, valid_df_new, test_df_new\n",
    "    else:\n",
    "        train_df = pd.merge(train_df, train_df_new, on='cano_smiles', how='outer')\n",
    "        valid_df = pd.merge(valid_df, valid_df_new, on='cano_smiles', how='outer')\n",
    "        test_df = pd.merge(test_df, test_df_new, on='cano_smiles', how='outer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1393238\n",
      "atom_fc.weight torch.Size([200, 39])\n",
      "atom_fc.bias torch.Size([200])\n",
      "neighbor_fc.weight torch.Size([200, 49])\n",
      "neighbor_fc.bias torch.Size([200])\n",
      "GRUCell.0.weight_ih torch.Size([600, 200])\n",
      "GRUCell.0.weight_hh torch.Size([600, 200])\n",
      "GRUCell.0.bias_ih torch.Size([600])\n",
      "GRUCell.0.bias_hh torch.Size([600])\n",
      "GRUCell.1.weight_ih torch.Size([600, 200])\n",
      "GRUCell.1.weight_hh torch.Size([600, 200])\n",
      "GRUCell.1.bias_ih torch.Size([600])\n",
      "GRUCell.1.bias_hh torch.Size([600])\n",
      "GRUCell.2.weight_ih torch.Size([600, 200])\n",
      "GRUCell.2.weight_hh torch.Size([600, 200])\n",
      "GRUCell.2.bias_ih torch.Size([600])\n",
      "GRUCell.2.bias_hh torch.Size([600])\n",
      "align.0.weight torch.Size([1, 400])\n",
      "align.0.bias torch.Size([1])\n",
      "align.1.weight torch.Size([1, 400])\n",
      "align.1.bias torch.Size([1])\n",
      "align.2.weight torch.Size([1, 400])\n",
      "align.2.bias torch.Size([1])\n",
      "attend.0.weight torch.Size([200, 200])\n",
      "attend.0.bias torch.Size([200])\n",
      "attend.1.weight torch.Size([200, 200])\n",
      "attend.1.bias torch.Size([200])\n",
      "attend.2.weight torch.Size([200, 200])\n",
      "attend.2.bias torch.Size([200])\n",
      "mol_GRUCell.weight_ih torch.Size([600, 200])\n",
      "mol_GRUCell.weight_hh torch.Size([600, 200])\n",
      "mol_GRUCell.bias_ih torch.Size([600])\n",
      "mol_GRUCell.bias_hh torch.Size([600])\n",
      "mol_align.weight torch.Size([1, 400])\n",
      "mol_align.bias torch.Size([1])\n",
      "mol_attend.weight torch.Size([200, 200])\n",
      "mol_attend.bias torch.Size([200])\n",
      "output.weight torch.Size([1234, 200])\n",
      "output.bias torch.Size([1234])\n"
     ]
    }
   ],
   "source": [
    "# Infer atom/bond feature sizes from one featurized molecule.\n",
    "x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([smilesList[0]],feature_dicts)\n",
    "num_atom_features = x_atom.shape[-1]\n",
    "num_bond_features = x_bonds.shape[-1]\n",
    "\n",
    "# One class-weighted cross-entropy loss per task, built from the per-task class weights in `weights`.\n",
    "loss_function = [nn.CrossEntropyLoss(torch.Tensor(weight), reduction='mean') for weight in weights]\n",
    "model = Fingerprint(radius, T, num_atom_features, num_bond_features,\n",
    "                    fingerprint_dim, output_units_num, p_dropout)\n",
    "model.cuda()\n",
    "\n",
    "# learning_rate and weight_decay hold exponents: the values used are 10**-learning_rate and 10**-weight_decay.\n",
    "optimizer = optim.Adam(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)\n",
    "\n",
    "params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(params)\n",
    "for name, param in model.named_parameters():\n",
    "    if param.requires_grad:\n",
    "        print(name, param.data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, dataset, optimizer, loss_function):\n",
    "    model.train()\n",
    "    np.random.seed(epoch)\n",
    "    valList = np.arange(0,dataset.shape[0])\n",
    "    #shuffle them\n",
    "    np.random.shuffle(valList)\n",
    "    batch_list = []\n",
    "    for i in range(0, dataset.shape[0], batch_size):\n",
    "        batch = valList[i:i+batch_size]\n",
    "        batch_list.append(batch)   \n",
    "    for counter, train_batch in enumerate(batch_list):\n",
    "        batch_df = dataset.loc[train_batch,:]\n",
    "        smiles_list = batch_df.cano_smiles.values\n",
    "        \n",
    "        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)\n",
    "        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask))\n",
    "#         print(torch.Tensor(x_atom).size(),torch.Tensor(x_bonds).size(),torch.cuda.LongTensor(x_atom_index).size(),torch.cuda.LongTensor(x_bond_index).size(),torch.Tensor(x_mask).size())\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss = 0.0\n",
    "        for i,task in enumerate(tasks):\n",
    "            y_pred = mol_prediction[:, i * per_task_output_units_num:(i + 1) *\n",
    "                                    per_task_output_units_num]\n",
    "            y_val = batch_df[task].values\n",
    "\n",
    "            validInds = np.where((y_val==0) | (y_val==1))[0]\n",
    "#             validInds = np.where(y_val != -1)[0]\n",
    "            if len(validInds) == 0:\n",
    "                continue\n",
    "            y_val_adjust = np.array([y_val[v] for v in validInds]).astype(float)\n",
    "            validInds = torch.cuda.LongTensor(validInds).squeeze()\n",
    "            y_pred_adjust = torch.index_select(y_pred, 0, validInds)\n",
    "\n",
    "            loss += loss_function[i](\n",
    "                y_pred_adjust,\n",
    "                torch.cuda.LongTensor(y_val_adjust))\n",
    "        # Step 5. Do the backward pass and update the gradient\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "def eval(model, dataset):\n",
    "    model.eval()\n",
    "    y_val_list = {}\n",
    "    y_pred_list = {}\n",
    "    losses_list = []\n",
    "    valList = np.arange(0,dataset.shape[0])\n",
    "    batch_list = []\n",
    "    for i in range(0, dataset.shape[0], batch_size):\n",
    "        batch = valList[i:i+batch_size]\n",
    "        batch_list.append(batch)   \n",
    "    for counter, eval_batch in enumerate(batch_list):\n",
    "        batch_df = dataset.loc[eval_batch,:]\n",
    "        smiles_list = batch_df.cano_smiles.values\n",
    "        \n",
    "        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)\n",
    "        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask))\n",
    "        atom_pred = atoms_prediction.data[:,:,1].unsqueeze(2).cpu().numpy()\n",
    "        for i,task in enumerate(tasks):\n",
    "            y_pred = mol_prediction[:, i * per_task_output_units_num:(i + 1) *\n",
    "                                    per_task_output_units_num]\n",
    "            y_val = batch_df[task].values\n",
    "\n",
    "            validInds = np.where((y_val==0) | (y_val==1))[0]\n",
    "#             validInds = np.where((y_val=='0') | (y_val=='1'))[0]\n",
    "#             print(validInds)\n",
    "            if len(validInds) == 0:\n",
    "                continue\n",
    "            y_val_adjust = np.array([y_val[v] for v in validInds]).astype(float)\n",
    "            validInds = torch.cuda.LongTensor(validInds).squeeze()\n",
    "            y_pred_adjust = torch.index_select(y_pred, 0, validInds)\n",
    "#             print(validInds)\n",
    "            loss = loss_function[i](\n",
    "                y_pred_adjust,\n",
    "                torch.cuda.LongTensor(y_val_adjust))\n",
    "#             print(y_pred_adjust)\n",
    "            y_pred_adjust = F.softmax(y_pred_adjust,dim=-1).data.cpu().numpy()[:,1]\n",
    "            losses_list.append(loss.cpu().detach().numpy())\n",
    "            try:\n",
    "                y_val_list[i].extend(y_val_adjust)\n",
    "                y_pred_list[i].extend(y_pred_adjust)\n",
    "            except:\n",
    "                y_val_list[i] = []\n",
    "                y_pred_list[i] = []\n",
    "                y_val_list[i].extend(y_val_adjust)\n",
    "                y_pred_list[i].extend(y_pred_adjust)\n",
    "                \n",
    "    eval_roc = [roc_auc_score(y_val_list[i], y_pred_list[i]) for i in range(len(tasks))]\n",
    "#     eval_prc = [auc(precision_recall_curve(y_val_list[i], y_pred_list[i])[1],precision_recall_curve(y_val_list[i], y_pred_list[i])[0]) for i in range(len(tasks))]\n",
    "#     eval_precision = [precision_score(y_val_list[i],\n",
    "#                                      (np.array(y_pred_list[i]) > 0.5).astype(int)) for i in range(len(tasks))]\n",
    "#     eval_recall = [recall_score(y_val_list[i],\n",
    "#                                (np.array(y_pred_list[i]) > 0.5).astype(int)) for i in range(len(tasks))]\n",
    "    eval_loss = np.array(losses_list).mean()\n",
    "    \n",
    "    return eval_roc, eval_loss #eval_prc, eval_precision, eval_recall, \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH:\t0\n",
      "train_roc_mean:0.5021973610323281\n",
      "valid_roc_mean:0.49762028825717286\n",
      "\n",
      "EPOCH:\t1\n",
      "train_roc_mean:0.5855546866380886\n",
      "valid_roc_mean:0.5667125335950007\n",
      "\n",
      "EPOCH:\t2\n",
      "train_roc_mean:0.6528233351285877\n",
      "valid_roc_mean:0.6404708426779874\n",
      "\n",
      "EPOCH:\t3\n",
      "train_roc_mean:0.6799620377739932\n",
      "valid_roc_mean:0.6636651569995672\n",
      "\n",
      "EPOCH:\t4\n",
      "train_roc_mean:0.6945825654124987\n",
      "valid_roc_mean:0.675848830473782\n",
      "\n",
      "EPOCH:\t5\n",
      "train_roc_mean:0.70698056414664\n",
      "valid_roc_mean:0.6859976700186519\n",
      "\n",
      "EPOCH:\t6\n",
      "train_roc_mean:0.7158189410252798\n",
      "valid_roc_mean:0.6944491124144468\n",
      "\n",
      "EPOCH:\t7\n",
      "train_roc_mean:0.7230352855303822\n",
      "valid_roc_mean:0.7000069595648777\n",
      "\n",
      "EPOCH:\t8\n",
      "train_roc_mean:0.7307899332707634\n",
      "valid_roc_mean:0.7077518046607768\n",
      "\n",
      "EPOCH:\t9\n",
      "train_roc_mean:0.7371558967756076\n",
      "valid_roc_mean:0.7127971737601481\n",
      "\n",
      "EPOCH:\t10\n",
      "train_roc_mean:0.7475103766509096\n",
      "valid_roc_mean:0.720732329574702\n",
      "\n",
      "EPOCH:\t11\n",
      "train_roc_mean:0.753764416268274\n",
      "valid_roc_mean:0.7249094854780693\n",
      "\n",
      "EPOCH:\t12\n",
      "train_roc_mean:0.7648658470580157\n",
      "valid_roc_mean:0.7349851972732697\n",
      "\n",
      "EPOCH:\t13\n",
      "train_roc_mean:0.7740364100636622\n",
      "valid_roc_mean:0.743266880835288\n",
      "\n",
      "EPOCH:\t14\n",
      "train_roc_mean:0.775285538704627\n",
      "valid_roc_mean:0.744101714663671\n",
      "\n",
      "EPOCH:\t15\n",
      "train_roc_mean:0.7883627759787385\n",
      "valid_roc_mean:0.751105079447404\n",
      "\n",
      "EPOCH:\t16\n",
      "train_roc_mean:0.7937706316261142\n",
      "valid_roc_mean:0.7550637421523454\n",
      "\n",
      "EPOCH:\t17\n",
      "train_roc_mean:0.7974452543976002\n",
      "valid_roc_mean:0.7571309289636776\n",
      "\n",
      "EPOCH:\t18\n",
      "train_roc_mean:0.8031655215724527\n",
      "valid_roc_mean:0.7600510693053577\n",
      "\n",
      "EPOCH:\t19\n",
      "train_roc_mean:0.8054155428191994\n",
      "valid_roc_mean:0.7601820396506792\n",
      "\n",
      "EPOCH:\t20\n",
      "train_roc_mean:0.8111294356097765\n",
      "valid_roc_mean:0.766188689414443\n",
      "\n",
      "EPOCH:\t21\n",
      "train_roc_mean:0.8153677840518297\n",
      "valid_roc_mean:0.7666161216562352\n",
      "\n",
      "EPOCH:\t22\n",
      "train_roc_mean:0.82195901951458\n",
      "valid_roc_mean:0.7723248196208551\n",
      "\n",
      "EPOCH:\t23\n",
      "train_roc_mean:0.824259288256087\n",
      "valid_roc_mean:0.7734784099859684\n",
      "\n",
      "EPOCH:\t24\n",
      "train_roc_mean:0.8276109273896993\n",
      "valid_roc_mean:0.7771160145888709\n",
      "\n",
      "EPOCH:\t25\n",
      "train_roc_mean:0.8295823699148538\n",
      "valid_roc_mean:0.7783360235536642\n",
      "\n",
      "EPOCH:\t26\n",
      "train_roc_mean:0.8365074088413823\n",
      "valid_roc_mean:0.7768563337700228\n",
      "\n",
      "EPOCH:\t27\n",
      "train_roc_mean:0.8405083914529996\n",
      "valid_roc_mean:0.7793476065089182\n",
      "\n",
      "EPOCH:\t28\n",
      "train_roc_mean:0.8416235689519305\n",
      "valid_roc_mean:0.7822495191276329\n",
      "\n",
      "EPOCH:\t29\n",
      "train_roc_mean:0.8467529559464417\n",
      "valid_roc_mean:0.7829300910349865\n",
      "\n",
      "EPOCH:\t30\n",
      "train_roc_mean:0.8506280228774326\n",
      "valid_roc_mean:0.7849984681254322\n",
      "\n",
      "EPOCH:\t31\n",
      "train_roc_mean:0.8531948675038085\n",
      "valid_roc_mean:0.7827268585485649\n",
      "\n",
      "EPOCH:\t32\n",
      "train_roc_mean:0.8544294373933664\n",
      "valid_roc_mean:0.7855119004530839\n",
      "\n",
      "EPOCH:\t33\n",
      "train_roc_mean:0.8599199136244058\n",
      "valid_roc_mean:0.7882675320449074\n",
      "\n",
      "EPOCH:\t34\n",
      "train_roc_mean:0.8639010661534787\n",
      "valid_roc_mean:0.7888900848568388\n",
      "\n",
      "EPOCH:\t35\n",
      "train_roc_mean:0.8653172795689202\n",
      "valid_roc_mean:0.7867526482574517\n",
      "\n",
      "EPOCH:\t36\n",
      "train_roc_mean:0.8697130119734973\n",
      "valid_roc_mean:0.7910962359094545\n",
      "\n",
      "EPOCH:\t37\n",
      "train_roc_mean:0.8686545503219268\n",
      "valid_roc_mean:0.7932981001665871\n",
      "\n",
      "EPOCH:\t38\n",
      "train_roc_mean:0.8724200405175445\n",
      "valid_roc_mean:0.7939661809803369\n",
      "\n",
      "EPOCH:\t39\n",
      "train_roc_mean:0.8768845199326892\n",
      "valid_roc_mean:0.7967065008886027\n",
      "\n",
      "EPOCH:\t40\n",
      "train_roc_mean:0.8800583110017821\n",
      "valid_roc_mean:0.7961575311047425\n",
      "\n",
      "EPOCH:\t41\n",
      "train_roc_mean:0.8812368621436176\n",
      "valid_roc_mean:0.7964600841217553\n",
      "\n",
      "EPOCH:\t42\n",
      "train_roc_mean:0.8845075479287338\n",
      "valid_roc_mean:0.8001342803154177\n",
      "\n",
      "EPOCH:\t43\n",
      "train_roc_mean:0.8869185974398677\n",
      "valid_roc_mean:0.7982533039895316\n",
      "\n",
      "EPOCH:\t44\n",
      "train_roc_mean:0.8895428418530877\n",
      "valid_roc_mean:0.8015370717624978\n",
      "\n",
      "EPOCH:\t45\n",
      "train_roc_mean:0.8933243781120072\n",
      "valid_roc_mean:0.8005913119914372\n",
      "\n",
      "EPOCH:\t46\n",
      "train_roc_mean:0.8943808870731337\n",
      "valid_roc_mean:0.8007255140863517\n",
      "\n",
      "EPOCH:\t47\n",
      "train_roc_mean:0.8968867863876007\n",
      "valid_roc_mean:0.8031468754388659\n",
      "\n",
      "EPOCH:\t48\n",
      "train_roc_mean:0.8988480295064198\n",
      "valid_roc_mean:0.8020995047668866\n",
      "\n",
      "EPOCH:\t49\n",
      "train_roc_mean:0.9020147718980508\n",
      "valid_roc_mean:0.8016595950488067\n",
      "\n",
      "EPOCH:\t50\n",
      "train_roc_mean:0.9026529135173086\n",
      "valid_roc_mean:0.8028919963478542\n",
      "\n",
      "EPOCH:\t51\n",
      "train_roc_mean:0.9054954731139132\n",
      "valid_roc_mean:0.8042084257605822\n",
      "\n",
      "EPOCH:\t52\n",
      "train_roc_mean:0.9059701467432237\n",
      "valid_roc_mean:0.8054819531363663\n",
      "\n",
      "EPOCH:\t53\n",
      "train_roc_mean:0.9057981843319812\n",
      "valid_roc_mean:0.8013849561848292\n",
      "\n",
      "EPOCH:\t54\n",
      "train_roc_mean:0.9099411330306943\n",
      "valid_roc_mean:0.8020188736200135\n",
      "\n",
      "EPOCH:\t55\n",
      "train_roc_mean:0.9123533961338752\n",
      "valid_roc_mean:0.807165484518198\n",
      "\n",
      "EPOCH:\t56\n",
      "train_roc_mean:0.9142548248411793\n",
      "valid_roc_mean:0.8089668902227684\n",
      "\n",
      "EPOCH:\t57\n",
      "train_roc_mean:0.9157310358376585\n",
      "valid_roc_mean:0.8064298168962601\n",
      "\n",
      "EPOCH:\t58\n",
      "train_roc_mean:0.9179303980625306\n",
      "valid_roc_mean:0.808168900544742\n",
      "\n",
      "EPOCH:\t59\n",
      "train_roc_mean:0.9186083875126718\n",
      "valid_roc_mean:0.8086814105290404\n",
      "\n",
      "EPOCH:\t60\n",
      "train_roc_mean:0.920162651830914\n",
      "valid_roc_mean:0.8062445935425081\n",
      "\n",
      "EPOCH:\t61\n",
      "train_roc_mean:0.922137220048656\n",
      "valid_roc_mean:0.8091013000029964\n",
      "\n",
      "EPOCH:\t62\n",
      "train_roc_mean:0.9225502056108946\n",
      "valid_roc_mean:0.80737781104246\n",
      "\n",
      "EPOCH:\t63\n",
      "train_roc_mean:0.9247697407838215\n",
      "valid_roc_mean:0.8116682470778592\n",
      "\n",
      "EPOCH:\t64\n",
      "train_roc_mean:0.9251789947349917\n",
      "valid_roc_mean:0.8083177539944812\n",
      "\n",
      "EPOCH:\t65\n",
      "train_roc_mean:0.926427510673274\n",
      "valid_roc_mean:0.8104576208434318\n",
      "\n",
      "EPOCH:\t66\n",
      "train_roc_mean:0.927796306498083\n",
      "valid_roc_mean:0.8100039012946693\n",
      "\n",
      "EPOCH:\t67\n",
      "train_roc_mean:0.9289100826100359\n",
      "valid_roc_mean:0.8090214354202611\n",
      "\n",
      "EPOCH:\t68\n",
      "train_roc_mean:0.9298992831776267\n",
      "valid_roc_mean:0.8090012625648021\n",
      "\n",
      "EPOCH:\t69\n",
      "train_roc_mean:0.9322075690188477\n",
      "valid_roc_mean:0.8132747438059004\n",
      "\n",
      "EPOCH:\t70\n",
      "train_roc_mean:0.9332345063217538\n",
      "valid_roc_mean:0.8109423907534381\n",
      "\n",
      "EPOCH:\t71\n",
      "train_roc_mean:0.9346343414977658\n",
      "valid_roc_mean:0.8139379782149735\n",
      "\n",
      "EPOCH:\t72\n",
      "train_roc_mean:0.9343624876763456\n",
      "valid_roc_mean:0.8142343753916811\n",
      "\n",
      "EPOCH:\t73\n",
      "train_roc_mean:0.935807031880805\n",
      "valid_roc_mean:0.8106347978400394\n",
      "\n",
      "EPOCH:\t74\n",
      "train_roc_mean:0.9374152661646996\n",
      "valid_roc_mean:0.8137938930706992\n",
      "\n",
      "EPOCH:\t75\n",
      "train_roc_mean:0.937979082968863\n",
      "valid_roc_mean:0.8160327880952797\n",
      "\n",
      "EPOCH:\t76\n",
      "train_roc_mean:0.9388000546522278\n",
      "valid_roc_mean:0.8147699434345538\n",
      "\n",
      "EPOCH:\t77\n",
      "train_roc_mean:0.9391934791596065\n",
      "valid_roc_mean:0.8144239958817256\n",
      "\n",
      "EPOCH:\t78\n",
      "train_roc_mean:0.9403785677701074\n",
      "valid_roc_mean:0.8141891017085947\n",
      "\n",
      "EPOCH:\t79\n",
      "train_roc_mean:0.9413444782764251\n",
      "valid_roc_mean:0.8128517970736894\n",
      "\n",
      "EPOCH:\t80\n",
      "train_roc_mean:0.9421801620588913\n",
      "valid_roc_mean:0.8132577023260303\n",
      "\n",
      "EPOCH:\t81\n",
      "train_roc_mean:0.9411804241088375\n",
      "valid_roc_mean:0.8109226604258356\n",
      "\n",
      "EPOCH:\t82\n",
      "train_roc_mean:0.9433869802542411\n",
      "valid_roc_mean:0.8122221847763513\n",
      "\n",
      "EPOCH:\t83\n",
      "train_roc_mean:0.9432334682585886\n",
      "valid_roc_mean:0.8122385598720798\n",
      "\n",
      "EPOCH:\t84\n",
      "train_roc_mean:0.945777703151236\n",
      "valid_roc_mean:0.8144287132015977\n",
      "\n",
      "EPOCH:\t85\n",
      "train_roc_mean:0.9459332426185098\n",
      "valid_roc_mean:0.8142339221948055\n",
      "\n",
      "EPOCH:\t86\n",
      "train_roc_mean:0.9466513088044959\n",
      "valid_roc_mean:0.8149118951243176\n",
      "\n",
      "EPOCH:\t87\n",
      "train_roc_mean:0.946548857554988\n",
      "valid_roc_mean:0.8141452606486774\n",
      "\n",
      "EPOCH:\t88\n",
      "train_roc_mean:0.9486350475537765\n",
      "valid_roc_mean:0.8155985154038287\n",
      "\n",
      "EPOCH:\t89\n",
      "train_roc_mean:0.9477716999217402\n",
      "valid_roc_mean:0.8172636851875312\n",
      "\n",
      "EPOCH:\t90\n",
      "train_roc_mean:0.9478446990288818\n",
      "valid_roc_mean:0.8173855063758152\n",
      "\n",
      "EPOCH:\t91\n",
      "train_roc_mean:0.9485248840644629\n",
      "valid_roc_mean:0.8189705921453038\n",
      "\n",
      "EPOCH:\t92\n",
      "train_roc_mean:0.9497429064288316\n",
      "valid_roc_mean:0.8171622825493969\n",
      "\n",
      "EPOCH:\t93\n",
      "train_roc_mean:0.950597502007526\n",
      "valid_roc_mean:0.8158034347788272\n",
      "\n",
      "EPOCH:\t94\n",
      "train_roc_mean:0.9504900370406507\n",
      "valid_roc_mean:0.8164778994924239\n",
      "\n",
      "EPOCH:\t95\n",
      "train_roc_mean:0.9515228505041696\n",
      "valid_roc_mean:0.8194082832315138\n",
      "\n",
      "EPOCH:\t96\n",
      "train_roc_mean:0.9521145360262682\n",
      "valid_roc_mean:0.8178243540092789\n",
      "\n",
      "EPOCH:\t97\n",
      "train_roc_mean:0.9529856047661116\n",
      "valid_roc_mean:0.8153618010749708\n",
      "\n",
      "EPOCH:\t98\n",
      "train_roc_mean:0.9532833907072195\n",
      "valid_roc_mean:0.8197332474004629\n",
      "\n",
      "EPOCH:\t99\n",
      "train_roc_mean:0.9540542838072175\n",
      "valid_roc_mean:0.8162651200210312\n",
      "\n",
      "EPOCH:\t100\n",
      "train_roc_mean:0.954796597935756\n",
      "valid_roc_mean:0.81832281647803\n",
      "\n",
      "EPOCH:\t101\n",
      "train_roc_mean:0.9548929125249205\n",
      "valid_roc_mean:0.8203404034609577\n",
      "\n",
      "EPOCH:\t102\n",
      "train_roc_mean:0.9545708924327768\n",
      "valid_roc_mean:0.8193414015291903\n",
      "\n",
      "EPOCH:\t103\n",
      "train_roc_mean:0.9561027259890784\n",
      "valid_roc_mean:0.8195832643028993\n",
      "\n",
      "EPOCH:\t104\n",
      "train_roc_mean:0.9559399702822283\n",
      "valid_roc_mean:0.8192853954265016\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH:\t105\n",
      "train_roc_mean:0.9567005901655605\n",
      "valid_roc_mean:0.8186418213703067\n",
      "\n",
      "EPOCH:\t106\n",
      "train_roc_mean:0.9564785717433045\n",
      "valid_roc_mean:0.8188268210001205\n",
      "\n",
      "EPOCH:\t107\n",
      "train_roc_mean:0.9577090056174314\n",
      "valid_roc_mean:0.8189049753411579\n",
      "\n",
      "EPOCH:\t108\n",
      "train_roc_mean:0.9583838461138776\n",
      "valid_roc_mean:0.8203958652786175\n",
      "\n",
      "EPOCH:\t109\n",
      "train_roc_mean:0.9575329688423371\n",
      "valid_roc_mean:0.8207913655163257\n",
      "\n",
      "EPOCH:\t110\n",
      "train_roc_mean:0.9588745200876692\n",
      "valid_roc_mean:0.8197061605302457\n",
      "\n",
      "EPOCH:\t111\n",
      "train_roc_mean:0.9592773232986488\n",
      "valid_roc_mean:0.8213477880870595\n",
      "\n",
      "EPOCH:\t112\n",
      "train_roc_mean:0.9587903501634208\n",
      "valid_roc_mean:0.821885639049331\n",
      "\n",
      "EPOCH:\t113\n",
      "train_roc_mean:0.9594863349463636\n",
      "valid_roc_mean:0.8230710977719151\n",
      "\n",
      "EPOCH:\t114\n",
      "train_roc_mean:0.9601792569477493\n",
      "valid_roc_mean:0.8226274947624065\n",
      "\n",
      "EPOCH:\t115\n",
      "train_roc_mean:0.9608782685101457\n",
      "valid_roc_mean:0.8200285062139449\n",
      "\n",
      "EPOCH:\t116\n",
      "train_roc_mean:0.9609261926639918\n",
      "valid_roc_mean:0.8203072122528914\n",
      "\n",
      "EPOCH:\t117\n",
      "train_roc_mean:0.9614471082522913\n",
      "valid_roc_mean:0.8207445932440669\n",
      "\n",
      "EPOCH:\t118\n",
      "train_roc_mean:0.961481117839826\n",
      "valid_roc_mean:0.8209804190753053\n",
      "\n",
      "EPOCH:\t119\n",
      "train_roc_mean:0.9615356989973971\n",
      "valid_roc_mean:0.8205751696864734\n",
      "\n",
      "EPOCH:\t120\n",
      "train_roc_mean:0.9625208318742493\n",
      "valid_roc_mean:0.8234065808230593\n",
      "\n",
      "EPOCH:\t121\n",
      "train_roc_mean:0.9609674743531332\n",
      "valid_roc_mean:0.8196500440987381\n",
      "\n",
      "EPOCH:\t122\n",
      "train_roc_mean:0.9620155391157241\n",
      "valid_roc_mean:0.8251588534401025\n",
      "\n",
      "EPOCH:\t123\n",
      "train_roc_mean:0.9633101244095708\n",
      "valid_roc_mean:0.8233548326618296\n",
      "\n",
      "EPOCH:\t124\n",
      "train_roc_mean:0.9633838566524601\n",
      "valid_roc_mean:0.8242998306928556\n",
      "\n",
      "EPOCH:\t125\n",
      "train_roc_mean:0.9637384301857047\n",
      "valid_roc_mean:0.8238907411904763\n",
      "\n",
      "EPOCH:\t126\n",
      "train_roc_mean:0.9640667789816644\n",
      "valid_roc_mean:0.824100708047275\n",
      "\n",
      "EPOCH:\t127\n",
      "train_roc_mean:0.964156418990561\n",
      "valid_roc_mean:0.824732710286591\n",
      "\n",
      "EPOCH:\t128\n",
      "train_roc_mean:0.9645706326038084\n",
      "valid_roc_mean:0.8231257543863146\n",
      "\n",
      "EPOCH:\t129\n",
      "train_roc_mean:0.9641658611847876\n",
      "valid_roc_mean:0.8240791818926595\n",
      "\n",
      "EPOCH:\t130\n",
      "train_roc_mean:0.9652824922403156\n",
      "valid_roc_mean:0.8236274223798526\n",
      "\n",
      "EPOCH:\t131\n",
      "train_roc_mean:0.9651687957335555\n",
      "valid_roc_mean:0.8225640677169581\n",
      "\n",
      "EPOCH:\t132\n",
      "train_roc_mean:0.96531530740682\n",
      "valid_roc_mean:0.8237881453190115\n",
      "\n",
      "EPOCH:\t133\n",
      "train_roc_mean:0.9653178793144923\n",
      "valid_roc_mean:0.8254327990123478\n",
      "\n",
      "EPOCH:\t134\n",
      "train_roc_mean:0.9651768292101615\n",
      "valid_roc_mean:0.8250415213576362\n",
      "\n",
      "EPOCH:\t135\n",
      "train_roc_mean:0.9662376318598607\n",
      "valid_roc_mean:0.8243047970122276\n",
      "\n",
      "EPOCH:\t136\n",
      "train_roc_mean:0.9668266011410713\n",
      "valid_roc_mean:0.823414031724269\n",
      "\n",
      "EPOCH:\t137\n",
      "train_roc_mean:0.9662824467931103\n",
      "valid_roc_mean:0.8251339603852699\n",
      "\n",
      "EPOCH:\t138\n",
      "train_roc_mean:0.9669877966712244\n",
      "valid_roc_mean:0.8250670905567739\n",
      "\n",
      "EPOCH:\t139\n",
      "train_roc_mean:0.9670594513776496\n",
      "valid_roc_mean:0.824880724275663\n",
      "\n",
      "EPOCH:\t140\n",
      "train_roc_mean:0.967140020178054\n",
      "valid_roc_mean:0.8254458464335944\n",
      "\n",
      "EPOCH:\t141\n",
      "train_roc_mean:0.9675833156775445\n",
      "valid_roc_mean:0.8243153283712182\n",
      "\n",
      "EPOCH:\t142\n",
      "train_roc_mean:0.9680495848454191\n",
      "valid_roc_mean:0.8247415709221155\n",
      "\n",
      "EPOCH:\t143\n",
      "train_roc_mean:0.9680713661188766\n",
      "valid_roc_mean:0.8234383666148771\n",
      "\n",
      "EPOCH:\t144\n",
      "train_roc_mean:0.9681665916407681\n",
      "valid_roc_mean:0.8238002147660723\n",
      "\n",
      "EPOCH:\t145\n",
      "train_roc_mean:0.9680899800161659\n",
      "valid_roc_mean:0.8250799494519917\n",
      "\n",
      "EPOCH:\t146\n",
      "train_roc_mean:0.9686554920543592\n",
      "valid_roc_mean:0.8265494237048466\n",
      "\n",
      "EPOCH:\t147\n",
      "train_roc_mean:0.9679148310642927\n",
      "valid_roc_mean:0.8255091295655176\n",
      "\n",
      "EPOCH:\t148\n",
      "train_roc_mean:0.9695469758786103\n",
      "valid_roc_mean:0.825504259068079\n",
      "\n",
      "EPOCH:\t149\n",
      "train_roc_mean:0.9691499489434556\n",
      "valid_roc_mean:0.8256350167124564\n",
      "\n",
      "EPOCH:\t150\n",
      "train_roc_mean:0.9696200914468396\n",
      "valid_roc_mean:0.8245997701579288\n",
      "\n",
      "EPOCH:\t151\n",
      "train_roc_mean:0.9695325593584085\n",
      "valid_roc_mean:0.8252895307507169\n",
      "\n",
      "EPOCH:\t152\n",
      "train_roc_mean:0.9699659000507236\n",
      "valid_roc_mean:0.8250818705637333\n",
      "\n",
      "EPOCH:\t153\n",
      "train_roc_mean:0.9694076265404936\n",
      "valid_roc_mean:0.8238623848858492\n",
      "\n",
      "EPOCH:\t154\n",
      "train_roc_mean:0.9691127804955232\n",
      "valid_roc_mean:0.8247957523437608\n",
      "\n",
      "EPOCH:\t155\n",
      "train_roc_mean:0.970212866505037\n",
      "valid_roc_mean:0.8277228566665332\n",
      "\n",
      "EPOCH:\t156\n",
      "train_roc_mean:0.9701488292521316\n",
      "valid_roc_mean:0.8276015171372443\n",
      "\n",
      "EPOCH:\t157\n",
      "train_roc_mean:0.9706041168559084\n",
      "valid_roc_mean:0.8267303088530051\n",
      "\n",
      "EPOCH:\t158\n",
      "train_roc_mean:0.9702697925319038\n",
      "valid_roc_mean:0.8281405004678853\n",
      "\n",
      "EPOCH:\t159\n",
      "train_roc_mean:0.9705638582902467\n",
      "valid_roc_mean:0.8250907682192608\n",
      "\n",
      "EPOCH:\t160\n",
      "train_roc_mean:0.9712086417687247\n",
      "valid_roc_mean:0.8290285172722245\n",
      "\n",
      "EPOCH:\t161\n",
      "train_roc_mean:0.9711948241543678\n",
      "valid_roc_mean:0.8282810678111773\n",
      "\n",
      "EPOCH:\t162\n",
      "train_roc_mean:0.9705421911087652\n",
      "valid_roc_mean:0.8266060264108228\n",
      "\n",
      "EPOCH:\t163\n",
      "train_roc_mean:0.9711207807391131\n",
      "valid_roc_mean:0.8284149770578698\n",
      "\n",
      "EPOCH:\t164\n",
      "train_roc_mean:0.9717166907035639\n",
      "valid_roc_mean:0.8277969591739983\n",
      "\n",
      "EPOCH:\t165\n",
      "train_roc_mean:0.9716477288871896\n",
      "valid_roc_mean:0.8265048780911127\n",
      "\n",
      "EPOCH:\t166\n",
      "train_roc_mean:0.9716171525440165\n",
      "valid_roc_mean:0.826652579450748\n",
      "\n",
      "EPOCH:\t167\n",
      "train_roc_mean:0.9721871092767094\n",
      "valid_roc_mean:0.8272342333170333\n",
      "\n",
      "EPOCH:\t168\n",
      "train_roc_mean:0.9723215654619562\n",
      "valid_roc_mean:0.8292710226291432\n",
      "\n",
      "EPOCH:\t169\n",
      "train_roc_mean:0.9721882566942011\n",
      "valid_roc_mean:0.8277494636860376\n",
      "\n",
      "EPOCH:\t170\n",
      "train_roc_mean:0.972492592962318\n",
      "valid_roc_mean:0.8282000782801816\n",
      "\n",
      "EPOCH:\t171\n",
      "train_roc_mean:0.9724543282982625\n",
      "valid_roc_mean:0.8272662514803607\n",
      "\n",
      "EPOCH:\t172\n",
      "train_roc_mean:0.9727470947104222\n",
      "valid_roc_mean:0.828907788548872\n",
      "\n",
      "EPOCH:\t173\n",
      "train_roc_mean:0.9727078272515518\n",
      "valid_roc_mean:0.8284801777202703\n",
      "\n",
      "EPOCH:\t174\n",
      "train_roc_mean:0.9727796725094661\n",
      "valid_roc_mean:0.8281536659995868\n",
      "\n",
      "EPOCH:\t175\n",
      "train_roc_mean:0.9729574628316329\n",
      "valid_roc_mean:0.8304069959330804\n",
      "\n",
      "EPOCH:\t176\n",
      "train_roc_mean:0.9731734786141626\n",
      "valid_roc_mean:0.8295774652230853\n",
      "\n",
      "EPOCH:\t177\n",
      "train_roc_mean:0.9729991600999635\n",
      "valid_roc_mean:0.8272625034594836\n",
      "\n",
      "EPOCH:\t178\n",
      "train_roc_mean:0.9736342839319222\n",
      "valid_roc_mean:0.8285990507461165\n",
      "\n",
      "EPOCH:\t179\n",
      "train_roc_mean:0.9735964831176966\n",
      "valid_roc_mean:0.8291237254481045\n",
      "\n",
      "EPOCH:\t180\n",
      "train_roc_mean:0.9733998165211111\n",
      "valid_roc_mean:0.8315140291849018\n",
      "\n",
      "EPOCH:\t181\n",
      "train_roc_mean:0.9735421049567266\n",
      "valid_roc_mean:0.8304557481826714\n",
      "\n",
      "EPOCH:\t182\n",
      "train_roc_mean:0.9737536496982041\n",
      "valid_roc_mean:0.8294159947004665\n",
      "\n",
      "EPOCH:\t183\n",
      "train_roc_mean:0.9743468773979722\n",
      "valid_roc_mean:0.8299138218213954\n",
      "\n",
      "EPOCH:\t184\n",
      "train_roc_mean:0.9741749097946756\n",
      "valid_roc_mean:0.8302199851952706\n",
      "\n",
      "EPOCH:\t185\n",
      "train_roc_mean:0.9745284616612191\n",
      "valid_roc_mean:0.8315748449480083\n",
      "\n",
      "EPOCH:\t186\n",
      "train_roc_mean:0.9745999937681643\n",
      "valid_roc_mean:0.8305831215123854\n",
      "\n",
      "EPOCH:\t187\n",
      "train_roc_mean:0.9744518624948099\n",
      "valid_roc_mean:0.8323249822300022\n",
      "\n",
      "EPOCH:\t188\n",
      "train_roc_mean:0.9749316471529358\n",
      "valid_roc_mean:0.8300911339911379\n",
      "\n",
      "EPOCH:\t189\n",
      "train_roc_mean:0.9744772780592843\n",
      "valid_roc_mean:0.8323688677419752\n",
      "\n",
      "EPOCH:\t190\n",
      "train_roc_mean:0.9750638835504327\n",
      "valid_roc_mean:0.8308208315539829\n",
      "\n",
      "EPOCH:\t191\n",
      "train_roc_mean:0.9750383975377244\n",
      "valid_roc_mean:0.8317307260123631\n",
      "\n",
      "EPOCH:\t192\n",
      "train_roc_mean:0.9752011784834114\n",
      "valid_roc_mean:0.8323215680235346\n",
      "\n",
      "EPOCH:\t193\n",
      "train_roc_mean:0.975375028930146\n",
      "valid_roc_mean:0.8307059499637314\n",
      "\n",
      "EPOCH:\t194\n",
      "train_roc_mean:0.9756624804821847\n",
      "valid_roc_mean:0.8323702977584938\n",
      "\n",
      "EPOCH:\t195\n",
      "train_roc_mean:0.9749468587610872\n",
      "valid_roc_mean:0.8305660103300639\n",
      "\n",
      "EPOCH:\t196\n",
      "train_roc_mean:0.975395519192039\n",
      "valid_roc_mean:0.8310003334276822\n",
      "\n",
      "EPOCH:\t197\n",
      "train_roc_mean:0.9757176259098037\n",
      "valid_roc_mean:0.8302686852443661\n",
      "\n",
      "EPOCH:\t198\n",
      "train_roc_mean:0.9758699459021849\n",
      "valid_roc_mean:0.8307994481303321\n",
      "\n",
      "EPOCH:\t199\n",
      "train_roc_mean:0.9756090529323773\n",
      "valid_roc_mean:0.829801025350441\n",
      "\n",
      "EPOCH:\t200\n",
      "train_roc_mean:0.9759965055739793\n",
      "valid_roc_mean:0.8303920926467366\n",
      "\n",
      "EPOCH:\t201\n",
      "train_roc_mean:0.9762879549360329\n",
      "valid_roc_mean:0.831535612891859\n",
      "\n",
      "EPOCH:\t202\n",
      "train_roc_mean:0.9762559340899719\n",
      "valid_roc_mean:0.8322935842609495\n",
      "\n",
      "EPOCH:\t203\n",
      "train_roc_mean:0.9761663573952137\n",
      "valid_roc_mean:0.8332912934310632\n",
      "\n",
      "EPOCH:\t204\n",
      "train_roc_mean:0.9764496305578891\n",
      "valid_roc_mean:0.8327045882138481\n",
      "\n",
      "EPOCH:\t205\n",
      "train_roc_mean:0.9760493896316\n",
      "valid_roc_mean:0.832592147163492\n",
      "\n",
      "EPOCH:\t206\n",
      "train_roc_mean:0.9764177255101986\n",
      "valid_roc_mean:0.8326132555661913\n",
      "\n",
      "EPOCH:\t207\n",
      "train_roc_mean:0.9764832873329865\n",
      "valid_roc_mean:0.8325738745150004\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH:\t208\n",
      "train_roc_mean:0.9767347919259126\n",
      "valid_roc_mean:0.8327087778321205\n",
      "\n",
      "EPOCH:\t209\n",
      "train_roc_mean:0.9763813621997843\n",
      "valid_roc_mean:0.8338218935503655\n",
      "\n",
      "EPOCH:\t210\n",
      "train_roc_mean:0.9770910916310849\n",
      "valid_roc_mean:0.8326132501044023\n",
      "\n",
      "EPOCH:\t211\n",
      "train_roc_mean:0.9771351035464975\n",
      "valid_roc_mean:0.8335247690746799\n",
      "\n",
      "EPOCH:\t212\n",
      "train_roc_mean:0.9772138599021358\n",
      "valid_roc_mean:0.8318545991639386\n",
      "\n",
      "EPOCH:\t213\n",
      "train_roc_mean:0.976735894001854\n",
      "valid_roc_mean:0.832273593459011\n",
      "\n",
      "EPOCH:\t214\n",
      "train_roc_mean:0.9771956973415061\n",
      "valid_roc_mean:0.8344979452263527\n",
      "\n",
      "EPOCH:\t215\n",
      "train_roc_mean:0.9775236800719825\n",
      "valid_roc_mean:0.8316663307629827\n",
      "\n",
      "EPOCH:\t216\n",
      "train_roc_mean:0.9774051406772485\n",
      "valid_roc_mean:0.8316618029330085\n",
      "\n",
      "EPOCH:\t217\n",
      "train_roc_mean:0.977402363782955\n",
      "valid_roc_mean:0.8316159740495572\n",
      "\n",
      "EPOCH:\t218\n",
      "train_roc_mean:0.9775576655922061\n",
      "valid_roc_mean:0.8321549015677306\n",
      "\n",
      "EPOCH:\t219\n",
      "train_roc_mean:0.9776750094446919\n",
      "valid_roc_mean:0.8325246661476069\n",
      "\n",
      "EPOCH:\t220\n",
      "train_roc_mean:0.9772430374305242\n",
      "valid_roc_mean:0.8333534112609334\n",
      "\n",
      "EPOCH:\t221\n",
      "train_roc_mean:0.9779713272737667\n",
      "valid_roc_mean:0.8334997115418488\n",
      "\n",
      "EPOCH:\t222\n",
      "train_roc_mean:0.9779831258879705\n",
      "valid_roc_mean:0.8308500007984341\n",
      "\n",
      "EPOCH:\t223\n",
      "train_roc_mean:0.9778888695283011\n",
      "valid_roc_mean:0.8324726677908018\n",
      "\n",
      "EPOCH:\t224\n",
      "train_roc_mean:0.9778192689688037\n",
      "valid_roc_mean:0.8323137517154713\n",
      "\n",
      "EPOCH:\t225\n",
      "train_roc_mean:0.9780769259881378\n",
      "valid_roc_mean:0.8347324167245438\n",
      "\n",
      "EPOCH:\t226\n",
      "train_roc_mean:0.9780934349358904\n",
      "valid_roc_mean:0.8328417107881141\n",
      "\n",
      "EPOCH:\t227\n",
      "train_roc_mean:0.9780686579049285\n",
      "valid_roc_mean:0.8338096756437384\n",
      "\n",
      "EPOCH:\t228\n",
      "train_roc_mean:0.9782116157308905\n",
      "valid_roc_mean:0.832844900511887\n",
      "\n",
      "EPOCH:\t229\n",
      "train_roc_mean:0.9783572819252013\n",
      "valid_roc_mean:0.8337505571373423\n",
      "\n",
      "EPOCH:\t230\n",
      "train_roc_mean:0.9784393616837085\n",
      "valid_roc_mean:0.8336887226206592\n",
      "\n",
      "EPOCH:\t231\n",
      "train_roc_mean:0.9786710835980135\n",
      "valid_roc_mean:0.8327268726308632\n",
      "\n",
      "EPOCH:\t232\n",
      "train_roc_mean:0.9785234762535817\n",
      "valid_roc_mean:0.835009040453167\n",
      "\n",
      "EPOCH:\t233\n",
      "train_roc_mean:0.9787588621526482\n",
      "valid_roc_mean:0.8347593557758792\n",
      "\n",
      "EPOCH:\t234\n",
      "train_roc_mean:0.9787925149302783\n",
      "valid_roc_mean:0.8357690135985258\n",
      "\n",
      "EPOCH:\t235\n",
      "train_roc_mean:0.9789432191038717\n",
      "valid_roc_mean:0.8321226308656645\n",
      "\n",
      "EPOCH:\t236\n",
      "train_roc_mean:0.9789293039504143\n",
      "valid_roc_mean:0.8327187414802804\n",
      "\n",
      "EPOCH:\t237\n",
      "train_roc_mean:0.9789387277633879\n",
      "valid_roc_mean:0.834420777216625\n",
      "\n",
      "EPOCH:\t238\n",
      "train_roc_mean:0.9790101122527086\n",
      "valid_roc_mean:0.8330616116928541\n",
      "\n",
      "EPOCH:\t239\n",
      "train_roc_mean:0.9793363502651315\n",
      "valid_roc_mean:0.8339947234836647\n",
      "\n",
      "EPOCH:\t240\n",
      "train_roc_mean:0.9794247158127879\n",
      "valid_roc_mean:0.8362553144338392\n",
      "\n",
      "EPOCH:\t241\n",
      "train_roc_mean:0.979444477654621\n",
      "valid_roc_mean:0.8351433747065761\n",
      "\n",
      "EPOCH:\t242\n",
      "train_roc_mean:0.9794441561078904\n",
      "valid_roc_mean:0.8346739316129129\n",
      "\n",
      "EPOCH:\t243\n",
      "train_roc_mean:0.9796230219637548\n",
      "valid_roc_mean:0.8332335629620659\n",
      "\n",
      "EPOCH:\t244\n",
      "train_roc_mean:0.9794418810152603\n",
      "valid_roc_mean:0.8358396221424828\n",
      "\n",
      "EPOCH:\t245\n",
      "train_roc_mean:0.9794524331162417\n",
      "valid_roc_mean:0.8334756852276477\n",
      "\n",
      "EPOCH:\t246\n",
      "train_roc_mean:0.9793578888418634\n",
      "valid_roc_mean:0.8349794737755022\n",
      "\n",
      "EPOCH:\t247\n",
      "train_roc_mean:0.9796005475055151\n",
      "valid_roc_mean:0.8353825472164911\n",
      "\n",
      "EPOCH:\t248\n",
      "train_roc_mean:0.9798868118322672\n",
      "valid_roc_mean:0.8336649090413097\n",
      "\n",
      "EPOCH:\t249\n",
      "train_roc_mean:0.9798921750735752\n",
      "valid_roc_mean:0.8337675097274883\n",
      "\n",
      "EPOCH:\t250\n",
      "train_roc_mean:0.9798509078382772\n",
      "valid_roc_mean:0.8327709348588723\n",
      "\n",
      "EPOCH:\t251\n",
      "train_roc_mean:0.9797863516405187\n",
      "valid_roc_mean:0.8351074878273266\n",
      "\n"
     ]
    }
   ],
   "source": [
    "best_param = {}\n",
    "best_param[\"roc_epoch\"] = 0\n",
    "best_param[\"loss_epoch\"] = 0\n",
    "best_param[\"valid_roc\"] = 0\n",
    "# float(\"inf\") instead of a magic sentinel: any finite first-epoch loss is recorded.\n",
    "best_param[\"valid_loss\"] = float(\"inf\")\n",
    "\n",
    "# A checkpoint is written on every validation-ROC improvement (no fixed 0.75 cut-off),\n",
    "# so the file for best_param[\"roc_epoch\"] always exists when the evaluation cell loads it.\n",
    "os.makedirs('saved_models', exist_ok=True)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Evaluate before training, so epoch 0 reports the untrained model.\n",
    "    train_roc, train_loss = eval(model, train_df)\n",
    "    valid_roc, valid_loss = eval(model, valid_df)\n",
    "    train_roc_mean = np.array(train_roc).mean()\n",
    "    valid_roc_mean = np.array(valid_roc).mean()\n",
    "    \n",
    "#     tensorboard.add_scalars('ROC',{'train_roc':train_roc_mean,'valid_roc':valid_roc_mean},epoch)\n",
    "#     tensorboard.add_scalars('Losses',{'train_losses':train_loss,'valid_losses':valid_loss},epoch)\n",
    "\n",
    "    if valid_roc_mean > best_param[\"valid_roc\"]:\n",
    "        best_param[\"roc_epoch\"] = epoch\n",
    "        best_param[\"valid_roc\"] = valid_roc_mean\n",
    "        torch.save(model, 'saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(epoch)+'.pt')\n",
    "    if valid_loss < best_param[\"valid_loss\"]:\n",
    "        best_param[\"loss_epoch\"] = epoch\n",
    "        best_param[\"valid_loss\"] = valid_loss\n",
    "\n",
    "    print(\"EPOCH:\\t\" + str(epoch) + '\\n'\n",
    "          + \"train_roc_mean\" + \":\" + str(train_roc_mean) + '\\n'\n",
    "          + \"valid_roc_mean\" + \":\" + str(valid_roc_mean) + '\\n')\n",
    "\n",
    "    # Early stopping: quit once validation ROC has not improved for more than 10 epochs\n",
    "    # AND validation loss has not improved for more than 20 epochs.\n",
    "    if (epoch - best_param[\"roc_epoch\"] > 10) and (epoch - best_param[\"loss_epoch\"] > 20):\n",
    "        break\n",
    "\n",
    "    train(model, train_df, optimizer, loss_function)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resulting per-task ROC-AUC list is very long because this dataset has 617 tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "best epoch:240\n",
      "test_roc:[0.919661733615222, 0.7589743589743588, 0.8669250645994833, 0.7070707070707072, 0.9203586497890296, 0.9785388127853881, 0.9422187981510015, 0.9690721649484536, 0.8144208037825059, 0.8784722222222222, 0.8595959595959597, 0.9401709401709402, 0.94, 0.8816425120772946, 0.8417553191489362, 0.8571428571428572, 0.8185758513931889, 0.7779720279720279, 0.9279379157427938, 0.9213320463320464, 0.8762886597938144, 0.8863636363636362, 0.7425742574257426, 0.9467299578059072, 0.9425641025641026, 0.9285714285714286, 0.9107142857142857, 0.9268817204301075, 0.9529761904761905, 0.9186991869918699, 0.9423076923076923, 0.856, 0.936, 0.7596153846153846, 0.78, 0.923611111111111, 0.9722222222222222, 0.92, 0.78, 0.6057692307692307, 0.47115384615384615, 0.9642857142857143, 1.0, 0.5357142857142857, 1.0, 0.7590361445783133, 0.8952480055497747, 0.9070960387574808, 0.9176557863501484, 0.9351032448377581, 0.75, 0.7467139852786541, 0.7974474474474474, 0.6978851963746224, 0.9044829342842589, 0.90503003003003, 0.7737388724035608, 0.9217999126256008, 0.869402503548845, 0.8149253731343283, 0.9442622950819672, 0.8716216216216216, 0.9256962025316455, 0.9141802846599895, 0.6428571428571428, 0.8425076452599388, 0.8798701298701299, 0.7480400696864111, 0.5276134122287968, 0.9486139283299526, 0.7755681818181819, 0.7746957766642806, 0.755192878338279, 0.7374260355029587, 1.0, 0.8075546922919649, 0.8568691320176278, 0.7905405405405406, 0.8247197156138912, 0.8195266272189349, 0.8075528700906345, 0.7378100940975192, 0.9234388366124893, 0.8651053013798111, 0.6642857142857144, 0.8995098039215687, 0.8672176308539945, 0.8958333333333333, 0.8034124629080119, 0.9401765282772148, 0.868916155419223, 0.7440476190476191, 0.685404339250493, 0.8343283582089552, 0.5783429040196884, 0.8604177279673968, 0.4696745562130178, 0.9785714285714285, 0.901266416510319, 0.7463644140290847, 0.5188492063492063, 0.8811116419812072, 0.952002299291052, 0.6465863453815262, 0.938739325517441, 0.5442477876106194, 
0.5233830845771145, 0.6444108761329305, 0.6193293885601577, 0.6306137724550899, 0.9461224700536968, 0.9306784660766961, 0.6460210210210211, 0.5786997433704021, 0.5675074183976261, 0.7373134328358208, 0.9363095238095238, 0.5756218905472636, 0.7731755424063116, 0.9808259587020649, 0.9250253292806485, 0.8178571428571428, 0.7101190476190476, 0.8962079110820529, 0.7995024875621891, 0.9167424931756142, 0.7615244180739389, 0.8237221494102228, 0.8728846153846155, 0.8589123867069487, 0.6656746031746033, 0.8624031007751939, 0.9566074950690334, 0.8614583333333333, 0.7722551928783383, 0.9013350286077559, 0.8629191321499013, 0.9278315503481066, 0.9480712166172106, 0.8228180862250264, 0.9268605645851155, 0.8829561220950399, 0.8309906291834004, 0.8815513626834381, 0.685655253837072, 0.819553534518396, 0.3230088495575221, 0.8877809700861923, 0.9451731761238026, 0.8018987341772152, 0.8975225225225226, 0.8355176933158585, 0.9306784660766961, 0.7981180496150555, 0.9156716417910448, 0.6577909270216962, 0.8989056800416885, 0.7192460317460319, 0.9092702169625246, 0.9541420118343196, 0.8421828908554573, 0.8732782369146006, 0.6108761329305136, 0.7889806079664571, 0.9117210682492581, 0.9009433962264151, 0.8725075528700905, 0.8516320474777448, 0.7297830374753451, 0.877172983756056, 0.8542899408284023, 0.9206349206349207, 0.6270316509837468, 0.8059777102330293, 0.6548323471400394, 0.9505783385909569, 0.8929255711127487, 0.9206231454005935, 0.6137284701114488, 0.8180428134556575, 0.5200296735905044, 0.7997245179063361, 0.5383480825958702, 0.9087322233784254, 0.8849557522123893, 0.9375, 0.8507575757575758, 0.645956607495069, 0.9386882829771555, 0.8970109216324966, 0.8748326639892905, 0.9461538461538461, 0.9364657814096016, 0.9918699186991871, 0.959049959049959, 0.9575471698113207, 0.9806315789473684, 0.9376984126984127, 0.9802517361111112, 0.7089371980676329, 0.796153846153846, 0.9681919642857142, 0.7608695652173914, 0.9888888888888888, 0.9247727272727274, 0.9456115779645191, 
0.9811363636363637, 0.9694637537239325, 0.9694711538461539, 0.517605633802817, 0.9372405372405372, 0.962046204620462, 0.9146795827123696, 0.963826998689384, 0.954225352112676, 0.9553113553113554, 0.9569484013928458, 0.8992055610724925, 0.9903157894736841, 0.904, 0.7621843805049912, 0.9564428312159711, 0.967479674796748, 0.9284944304707151, 0.9075630252100841, 0.9100806451612902, 0.8571428571428571, 0.9420138888888889, 0.9400299850074962, 0.8586914440572976, 0.8258992805755396, 0.8959025470653378, 0.91650390625, 0.8915441176470589, 0.9305230288836845, 0.4278959810874704, 0.9165023011176859, 0.6408450704225352, 0.94189453125, 0.9385227272727272, 0.6784869976359338, 0.9625220458553793, 0.7571428571428571, 0.986890756302521, 0.8739495798319328, 0.8349579831932772, 0.9359165424739195, 0.9590062111801242, 0.9357142857142856, 0.9532773109243698, 0.6056338028169015, 0.9651455264103486, 0.9064327485380117, 0.9553571428571428, 0.6595744680851064, 0.9247066492829203, 0.6134751773049645, 0.9228327228327228, 0.9952898550724638, 0.9322444041137325, 0.9470720720720722, 0.9696106362773028, 0.9584500466853407, 0.9685983442763345, 0.7887323943661972, 0.9842991721381673, 0.897887323943662, 0.9188058035714286, 0.718676122931442, 0.9543918918918919, 0.9514423076923076, 0.8523391812865497, 0.7850746268656718, 0.9850972467794897, 0.8756805807622504, 0.776654411764706, 0.9291666666666667, 0.861111111111111, 0.9436363636363637, 0.9654545454545455, 0.9747727272727272, 0.9682692307692308, 0.9585602652143026, 0.8834532374100719, 0.9559633027522936, 0.6607142857142857, 0.9575360419397116, 0.9827524254401725, 0.9757254464285714, 0.7784926470588234, 0.9263392857142857, 0.9531893004115226, 0.9395454545454545, 0.8880597014925373, 0.786231884057971, 0.9898947368421053, 0.9747899159663864, 0.9736227824463118, 0.9707602339181287, 0.9789612097304405, 0.8864734299516908, 0.9721153846153846, 0.9453186467348544, 0.9844389844389845, 0.9148464704020259, 0.938669192517168, 0.8250000000000001, 
0.9358552631578947, 0.8108108108108107, 0.7841880341880343, 0.5327380952380952, 0.7391304347826086, 0.7529761904761905, 0.8488888888888889, 0.8305647840531561, 0.8455284552845529, 0.8609523809523809, 0.7543859649122807, 0.8261904761904761, 0.8333333333333333, 0.8214285714285714, 0.625, 0.9583333333333334, 0.7222222222222222, 0.885, 0.8642533936651584, 0.7866666666666666, 0.8863636363636362, 0.9185520361990951, 0.9330357142857143, 0.9166666666666666, 0.9090909090909092, 0.7314814814814815, 0.84, 0.7283950617283951, 0.5555555555555556, 0.7788461538461539, 0.792, 0.9192546583850932, 0.9411764705882353, 0.7074074074074074, 0.8690476190476191, 0.5967741935483871, 0.6666666666666666, 0.8835227272727273, 0.8487394957983193, 0.7, 0.5384615384615384, 1.0, 0.975, 0.6916666666666667, 0.736842105263158, 0.71875, 0.71875, 0.4666666666666667, 1.0, 0.8163265306122449, 0.9444444444444444, 0.5714285714285714, 0.625, 0.6875, 0.5, 0.41666666666666663, 0.65625, 0.8, 0.5238095238095238, 0.9500000000000001, 0.8333333333333334, 0.8125, 0.5833333333333333, 0.9047619047619048, 0.2142857142857143, 0.45833333333333326, 0.7333333333333333, 0.8333333333333333, 0.6363636363636364, 1.0, 0.7222222222222223, 0.5833333333333333, 0.6, 0.5, 1.0, 0.75, 0.2962962962962963, 1.0, 0.625, 0.6923076923076923, 0.5, 0.6666666666666667, 0.6799999999999999, 0.625, 0.6666666666666667, 1.0, 0.5952380952380952, 0.7916666666666666, 0.92, 0.5, 0.6666666666666666, 1.0, 0.625, 0.44999999999999996, 0.8333333333333333, 0.9583333333333334, 0.8222222222222223, 1.0, 0.8095238095238095, 0.875, 0.6041666666666666, 1.0, 0.9583333333333334, 0.8333333333333333, 1.0, 0.7083333333333333, 0.8333333333333334, 1.0, 0.9166666666666667, 0.6333333333333333, 0.6857142857142857, 0.9166666666666667, 0.888888888888889, 1.0, 0.9444444444444445, 0.6071428571428572, 0.9047619047619049, 0.9500000000000001, 0.703125, 0.9333333333333333, 0.7727272727272727, 0.6111111111111112, 0.9285714285714286, 1.0, 0.8, 0.8, 0.75, 0.8, 0.875, 
0.7857142857142857, 0.888888888888889, 0.6428571428571429, 0.7777777777777778, 0.5833333333333333, 0.6227106227106227, 0.0, 1.0, 0.8333333333333334, 0.85, 0.8571428571428571, 0.8, 0.7708333333333333, 1.0, 0.75, 0.5, 0.125, 0.39999999999999997, 0.6748251748251748, 0.5265151515151515, 0.6939799331103679, 0.875, 0.69, 0.8787878787878788, 0.7763157894736843, 0.8558866634098681, 1.0, 0.7916666666666666, 0.7045454545454545, 0.7476190476190476, 0.6842592592592592, 0.6900584795321638, 0.7613636363636365, 0.9420289855072463, 0.9642857142857143, 0.8823529411764706, 0.7887624466571835, 0.9395604395604396, 0.9375, 0.5454545454545454, 0.7, 0.6000000000000001, 0.9263157894736843, 0.8571428571428571, 0.875, 0.8125, 0.7666666666666666, 0.92, 0.6503067484662577, 0.8924897119341564, 0.9414160401002507, 0.9632107023411371, 0.8523451071221773, 0.9394923258559623, 0.8819727891156462, 0.878169449598021, 0.8124533929903058, 0.8525506638714185, 0.7681159420289855, 0.8831426056338028, 0.8228200371057512, 0.8674193548387096, 0.5656565656565656, 0.9777859597447226, 0.8908851884312008, 0.8437725225225224, 0.5830972615675165, 0.8017789072426937, 0.9063129502543618, 0.9123384024227426, 0.954140127388535, 0.941742522756827, 0.9302015468434488, 0.9417692129092176, 0.7792417128039808, 0.9436631593147173, 0.9156784295665189, 0.880722283167041, 0.8940601881778352, 0.8660714285714286, 0.9971938775510204, 0.7114795918367347, 0.9974587039390089, 0.9768152866242039, 0.9993646759847522, 0.884766214177979, 0.9157017543859649, 0.8606855473783683, 0.8984565161364222, 0.9435087719298245, 0.9853403141361257, 0.9216542512201387, 0.9905844155844156, 0.6370481927710843, 0.9401947148817802, 0.9847340258105131, 0.9646177685950412, 0.9598329334399716, 0.9555164614224679, 0.9555719194894452, 0.9576315242254108, 0.9259524948190658, 0.9823741007194244, 0.9454499151103565, 0.9881809295299032, 0.9863496257155437, 0.8580579332237401, 0.965081081081081, 0.9669811320754716, 0.9600280504908837, 0.9890666666666666, 
0.9691866894541068, 0.8531451805788088, 0.9253435114503817, 0.9875273891791674, 0.9246415650542678, 0.9186197916666666, 0.9140350877192982, 0.9468181361214436, 0.824113475177305, 0.9582730342498793, 0.8848055818994716, 0.9621301775147928, 0.6965386439070649, 0.4726507713884993, 0.905248033877798, 0.9754234724742892, 0.9297520661157025, 0.9573970037453184, 0.9739517153748412, 0.937677375411511, 0.8776430466085637, 0.5346441947565543, 0.9411019283746557, 0.9936607142857143, 0.7646310432569975, 0.9226244765006981, 0.9399298522946203, 0.9742120343839541, 0.9534555712270805, 0.9747780468119451, 0.978683267308055, 0.9905254091300602, 0.9695722081465601, 0.9480988702185009, 0.9715997144299918, 0.9761343849226671, 0.9248512681348502, 0.8461443661971831, 0.8800839112937753, 0.8099307447168367, 0.8851076567370616, 0.8877549021528335, 0.8430123985917649, 0.9248155412347062, 0.9388440288992775, 0.8478213977566867, 0.8979335680391857, 0.8950269399888524, 0.9722243273967413, 0.860646551724138, 0.8937937062937062, 0.9672551138865538, 0.973026973026973, 0.9351190476190476, 0.9461805555555556, 0.9216494845360825, 0.975, 0.949468085106383, 0.986013986013986, 0.8836996336996337, 0.9795918367346939, 0.937037037037037, 0.953125, 0.9855670103092784, 0.962037962037962, 1.0, 0.9562289562289562, 0.9930555555555556, 0.8457446808510638, 0.9161624891961971]\n",
      "test_roc_mean: 0.8390568534999393\n"
     ]
    }
   ],
   "source": [
    "# evaluate model: reload the checkpoint saved at the epoch with the best mean validation ROC\n",
    "best_model = torch.load('saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(best_param[\"roc_epoch\"])+'.pt')\n",
    "\n",
    "best_model_dict = best_model.state_dict()\n",
    "best_model_wts = copy.deepcopy(best_model_dict)\n",
    "\n",
    "model.load_state_dict(best_model_wts)\n",
    "# Sanity check that the weights were really copied; a bare comparison mid-cell is silently discarded.\n",
    "assert torch.equal(best_model.align[0].weight, model.align[0].weight)\n",
    "test_roc, test_losses = eval(model, test_df)\n",
    "\n",
    "# Concatenate (rather than passing separate print arguments) so every line uses the\n",
    "# same \"name:value\" format as the training log.\n",
    "print(\"best epoch:\" + str(best_param[\"roc_epoch\"])\n",
    "      + \"\\n\" + \"test_roc:\" + str(test_roc)\n",
    "      + \"\\n\" + \"test_roc_mean:\" + str(np.array(test_roc).mean()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
